Here, the dataset stores images in (Height, Width, Channel) format, but the model expects (Channel, Height, Width).
So, we can use the tf.transpose() function to reorder the dimensions of a tensor. Its perm argument lists, for each output axis, the index of the input axis that should go there.
For converting a single image tensor from HWC to CHW:
reshaped = tf.transpose(image_tensor, (2,0,1))
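The sketch below shows what the permutation (2, 0, 1) does to a single image. It assumes TensorFlow 2.x with eager execution. The 32x32 RGB size and the names image_hwc and image_chw are hypothetical stand-ins for your own data. It checks the new shape and confirms that transposing only moves values to new positions without changing them.

```python
import tensorflow as tf

# Hypothetical dummy image: 32x32 pixels, 3 channels, stored as (H, W, C).
image_hwc = tf.random.uniform((32, 32, 3))

# perm gives, for each output axis, the input axis it is taken from:
# output 0 <- input 2 (C), output 1 <- input 0 (H), output 2 <- input 1 (W).
image_chw = tf.transpose(image_hwc, perm=(2, 0, 1))

print(image_chw.shape)  # (3, 32, 32)

# Only the axis order changes: the pixel at (h, w, c) is now found at (c, h, w).
assert float(image_hwc[5, 7, 2]) == float(image_chw[2, 5, 7])
```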
And for converting a batch from NHWC to NCHW, keep the batch axis (0) in place and move the channel axis (3) to position 1:
reshaped = tf.transpose(images_tensor, (0,3,1,2))
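If the images come from a tf.data.Dataset, one option is to apply the batch permutation to every batch with Dataset.map. This is a minimal sketch that assumes TensorFlow 2.x. The random images, the integer labels, the batch size of 4 and the function name to_nchw are all hypothetical placeholders for the real input pipeline.

```python
import tensorflow as tf

# Hypothetical stand-in for the real data: 10 RGB images (32x32, HWC) with integer labels.
images = tf.random.uniform((10, 32, 32, 3))
labels = tf.range(10)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

def to_nchw(batch, label):
    # Each batch arrives as (N, H, W, C); move the channel axis to position 1.
    return tf.transpose(batch, perm=(0, 3, 1, 2)), label

dataset = dataset.map(to_nchw)

for batch, label in dataset.take(1):
    print(batch.shape)  # (4, 3, 32, 32)
```

Because the transpose runs inside map, the stored data keeps its original layout and only the batches fed to the model are converted.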