What are transforms in PyTorch used for?

For a fuller explanation of how data augmentation works in PyTorch, see this answer:

Data Augmentation in PyTorch

In short, suppose the only transform is a random horizontal flip. When you iterate through a dataset of images, some images are returned as they are and some are returned flipped (for the flipped ones, the unflipped originals are not returned as well). In other words, one pass over the dataset yields exactly as many images as the dataset contains; augmentation changes what each sample looks like on a given pass, not how many samples there are.
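To make the "same number of samples per pass" point concrete, here is a dependency-free sketch. `RandomFlipDataset` is a hypothetical toy class (not a PyTorch API) where each "image" is just a list of pixels, flipped with probability `p` each time it is accessed, much as a torchvision random flip is applied inside a dataset's `__getitem__`:

```python
import random


class RandomFlipDataset:
    """Hypothetical toy dataset: each item is flipped left-to-right with
    probability p every time it is accessed."""

    def __init__(self, images, p=0.5):
        self.images = images
        self.p = p

    def __len__(self):
        # The length is fixed by the underlying data; augmentation does not change it.
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if random.random() < self.p:
            return img[::-1]  # the flipped copy is returned *instead of* the original
        return img


random.seed(0)
ds = RandomFlipDataset([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
epoch1 = [ds[i] for i in range(len(ds))]
epoch2 = [ds[i] for i in range(len(ds))]
print(len(epoch1), len(epoch2))  # 3 3: one sample per original image, every pass
```

Which images come out flipped can differ from one pass to the next, but each pass still returns three images.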


transforms.Compose simply chains together the transforms passed to it: they are applied to the input one after another, in the order given, with each transform receiving the output of the previous one.
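The chaining behavior can be sketched in a few lines. This `Compose` is a minimal stand-in written for illustration, not torchvision's source, and the two toy transforms work on numbers rather than images:

```python
class Compose:
    """Minimal stand-in for transforms.Compose: applies each transform in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)  # the output of one transform is the input to the next
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


print(Compose([add_one, double])(3))  # 8, i.e. double(add_one(3))
print(Compose([double, add_one])(3))  # 7, so the order of the list matters
```

This is why the order inside Compose matters in practice too: for example, ToTensor has to come before Normalize, because Normalize expects a tensor.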

Train transforms

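  1. transforms.RandomResizedCrop(224): This crops a random region of the input image, with a randomly chosen area and aspect ratio, and resizes that region to (224, 224). The region may come from the top left, the bottom right, or anywhere in between, so this step is a form of data augmentation. Changing the output size is not advised if the fully-connected layers in your model expect 224 x 224 inputs.
  2. transforms.RandomHorizontalFlip(): Once we have our (224, 224) image, this flips it horizontally with probability 0.5 (the default). This is another form of data augmentation.
  3. transforms.ToTensor(): This converts the input image (a PIL image or NumPy array of shape H x W x C) into a PyTorch tensor of shape C x H x W; for ordinary 8-bit images, it also scales pixel values from [0, 255] to [0.0, 1.0].
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): This normalizes each channel as (x - mean) / std. These mean and std values are the per-channel statistics of the ImageNet training set, which is why they are used with ImageNet-pretrained models. If you use such a pretrained model, do not change them; if you train from scratch on your own data, you can compute your dataset's statistics instead.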
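To illustrate two of the steps above, here is a dependency-free sketch. `sample_crop_box` is a simplified model of how RandomResizedCrop picks its region, assuming torchvision's documented defaults scale=(0.08, 1.0) and ratio=(3/4, 4/3); its fallback (use the whole image) is a simplification and differs from torchvision's. `to_tensor_then_normalize` shows the per-pixel arithmetic of ToTensor followed by Normalize. Both function names are hypothetical:

```python
import math
import random

MEAN = (0.485, 0.456, 0.406)  # ImageNet per-channel mean
STD = (0.229, 0.224, 0.225)   # ImageNet per-channel std


def sample_crop_box(width, height, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), attempts=10):
    """Simplified sketch: pick a random (left, top, w, h) box with a random area
    fraction and aspect ratio. The box would then be resized to 224 x 224."""
    area = width * height
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(attempts):
        target_area = area * random.uniform(*scale)
        aspect = math.exp(random.uniform(*log_ratio))  # sampled uniformly in log space
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = random.randint(0, height - h)   # the box can land anywhere
            left = random.randint(0, width - w)
            return left, top, w, h
    # Simplified fallback (assumption): use the whole image.
    return 0, 0, width, height


def to_tensor_then_normalize(rgb_pixel):
    """ToTensor scales 0..255 to 0.0..1.0; Normalize then computes (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_pixel, MEAN, STD))


random.seed(0)
print(sample_crop_box(640, 480))                   # some box inside the 640x480 image
print(to_tensor_then_normalize((255, 255, 255)))   # roughly (2.249, 2.429, 2.640)
```

Note that a white pixel ends up well above 1.0 after normalization: the output is no longer in [0, 1], but roughly centered around 0 for typical ImageNet-like images.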

Validation transforms

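  1. transforms.Resize(256): First, the input image is resized so that its shorter edge is 256 pixels, keeping the aspect ratio. The result is (256, 256) only if the image is square; to force that exact size you would pass Resize((256, 256)).
  2. transforms.CenterCrop(224): This crops the central (224, 224) region of the resized image.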
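The geometry of these two validation steps can be computed by hand. The sketch below is an illustration, not torchvision's code; it assumes the longer edge is truncated to an integer and the crop offsets are rounded, which is my understanding of torchvision's behavior. Both function names are hypothetical:

```python
def resize_shorter_edge(width, height, size=256):
    """Output (width, height) for Resize(size) with an int size: the shorter edge
    becomes `size` and the longer edge is scaled to keep the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """(left, top, w, h) of the central crop x crop region."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, crop, crop


w, h = resize_shorter_edge(640, 480)  # a 640x480 landscape image
print(w, h)                           # 341 256, not 256 256
print(center_crop_box(w, h))          # (58, 16, 224, 224)
```

Because both steps are deterministic, a given validation image always yields the same 224 x 224 crop.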

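The remaining transforms (ToTensor and Normalize) are the same as for training. Unlike the training pipeline, the validation pipeline contains no random transforms, so each image is processed the same way every time.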

P.S.: You can read more about these transformations in the official torchvision documentation.