What does model.train() do in PyTorch?

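model.train() puts your model into training mode. This informs layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode BatchNorm normalizes with the statistics of the current batch and updates its running mean and variance on each new batch, whereas in evaluation mode it normalizes with those running estimates and the updates are frozen. Likewise, Dropout randomly zeroes activations only in training mode and passes its input through unchanged in evaluation mode.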
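A minimal sketch of that difference, assuming a recent PyTorch install. It runs one Dropout and one BatchNorm1d layer in each mode and checks whether the output is random and whether the running mean changes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: random zeroing in train mode, identity in eval mode
dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.train()
train_out = dropout(x)  # survivors are scaled by 1 / (1 - p) = 2.0, the rest are 0

dropout.eval()
eval_out = dropout(x)   # returns the input unchanged

# BatchNorm: running statistics are updated only in train mode
bn = nn.BatchNorm1d(3)
batch = torch.randn(4, 3)

bn.train()
before = bn.running_mean.clone()
bn(batch)
updated_in_train = not torch.equal(before, bn.running_mean)

bn.eval()
before = bn.running_mean.clone()
bn(batch)
updated_in_eval = not torch.equal(before, bn.running_mean)

print(train_out)
print(eval_out)
print(updated_in_train, updated_in_eval)
```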

More details: despite its name, model.train() does not train the model; it only sets the mode to training (see the source code below). To switch to evaluation (testing) mode, call either model.eval() or model.train(mode=False).
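To illustrate that train() only flips a flag, here is a small sketch of a typical loop, assuming a toy model, random data, and SGD chosen purely for illustration. The weights stay the same after model.train(); they only change once the optimizer steps. The torch.no_grad() block is a separate mechanism (it turns off gradient tracking) and is often combined with eval() for inference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data, purely for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(16, 4), torch.randn(16, 1)

weights_before = model[0].weight.detach().clone()
model.train()  # only sets the mode flag
unchanged_after_train_call = torch.equal(weights_before, model[0].weight)

# The actual training happens in the forward/backward/step loop
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

model.eval()  # switch Dropout (and BatchNorm, if present) to evaluation behavior
with torch.no_grad():  # disables gradient tracking; eval() does not do this
    val_loss = loss_fn(model(x), y).item()

print(unchanged_after_train_call, model.training, val_loss)
```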


Here is a simplified version of the code for nn.Module.train() (recent releases also check that mode is a bool):

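def train(self, mode=True):
    r"""Sets the module in training mode."""
    self.training = mode
    # Recursively propagate the same mode to every direct child module
    for module in self.children():
        module.train(mode)
    # Returning self allows chaining, e.g. model = Net().train()
    return self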
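Because train() recurses through self.children(), the flag reaches every nested submodule. A small sketch with an arbitrary nested model shows this, and also shows that assigning model.training directly does not recurse, which is why calling train()/eval() is the safer way to switch modes:

```python
import torch.nn as nn

# Arbitrary nested model for illustration
model = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sequential(nn.Dropout(0.5), nn.BatchNorm1d(2)),
)

model.train(False)
flags_after_train_call = [m.training for m in model.modules()]
print(flags_after_train_call)  # every module, including nested ones, is False

# Setting the attribute directly only changes the top-level module
model.training = True
flags_after_assignment = [m.training for m in model.modules()]
print(flags_after_assignment)  # [True, False, False, False, False]
```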

Here is the code for nn.Module.eval():

def eval(self):
    r"""Sets the module in evaluation mode."""
    return self.train(False)
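Since eval() simply forwards to train(False), the two calls are interchangeable, and both return the module itself. A quick sketch using Dropout as an arbitrary example layer:

```python
import torch.nn as nn

a = nn.Dropout(p=0.5)
b = nn.Dropout(p=0.5)

returned = a.eval()   # internally calls a.train(False)
b.train(False)

print(a.training, b.training)  # False False
print(returned is a)           # True, so calls like nn.Linear(2, 2).eval() can be chained
```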

By default, the self.training flag is set to True, i.e., modules are in train mode by default. When self.training is False, the module is in the opposite state, eval mode.
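A practical consequence of the default: a freshly built model is in training mode, so Dropout stays active until you call eval(). A minimal sketch with a toy model (the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(0.5))

# Freshly constructed modules start in training mode
starts_in_train_mode = all(m.training for m in model.modules())
print(starts_in_train_mode)  # True

x = torch.ones(1, 3)
model.eval()
with torch.no_grad():
    out1, out2 = model(x), model(x)

# In eval mode Dropout is the identity, so repeated calls give the same output
print(torch.equal(out1, out2))  # True
```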

Among the most commonly used layers, Dropout and BatchNorm are the main ones whose behavior depends on that flag.
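Layers "care about" the flag simply by reading self.training in their forward(). As a sketch, here is a hypothetical GaussianNoise layer (not part of torch.nn) that adds noise only in training mode, mirroring how Dropout switches behavior:

```python
import torch
import torch.nn as nn


class GaussianNoise(nn.Module):
    """Hypothetical layer: adds Gaussian noise only while self.training is True."""

    def __init__(self, std: float = 0.1):
        super().__init__()
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return x + torch.randn_like(x) * self.std
        return x  # identity in eval mode


layer = GaussianNoise(std=1.0)
x = torch.zeros(5)

noisy = layer(x)  # new modules start in training mode, so noise is added
layer.eval()
clean = layer(x)  # eval mode: input passes through unchanged

print(torch.equal(clean, x))  # True
```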
