The Python Oracle

What does model.train() do in PyTorch?

Full question
https://stackoverflow.com/questions/5143...

--

Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...

--

Tags
#python #pytorch



ACCEPTED ANSWER

Score 323


model.train() tells your model that you are training it. This informs layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode BatchNorm updates a moving average on each new batch, whereas in evaluation mode these updates are frozen.
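The BatchNorm behaviour described above can be observed directly. The following is a minimal sketch, assuming a standalone nn.BatchNorm1d layer with its default momentum of 0.1; the layer size, the input tensor and the variable names are illustrative and not part of the original answer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(3)            # running_mean starts as zeros
x = torch.randn(8, 3) + 5.0       # a batch whose mean is far from zero

bn.train()
bn(x)                             # forward pass in training mode
mean_after_train = bn.running_mean.clone()

bn.eval()
bn(x)                             # forward pass in evaluation mode
mean_after_eval = bn.running_mean.clone()

# In training mode the running mean moved toward the batch mean;
# in evaluation mode the same forward pass left it untouched.
print(mean_after_train)
print(torch.equal(mean_after_train, mean_after_eval))
```

The second print should show that the running statistics did not change during the evaluation-mode pass.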

More details: model.train() sets the mode to train (see the source code of nn.Module.train()). You can call either model.eval() or model.train(mode=False) to tell the model that you are testing. It is somewhat intuitive to expect the train function to train the model, but it does not do that. It just sets the mode.
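A short check of the point that train() only sets the mode. This is a sketch using a small nn.Linear layer chosen purely for illustration; the variable names are hypothetical.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
weights_before = model.weight.detach().clone()

returned = model.train()               # sets the flag and returns the module
flag_after_train = model.training      # True

model.train(mode=False)                # same effect as model.eval()
flag_after_train_false = model.training

model.train()
model.eval()
flag_after_eval = model.training

# No optimisation step happened: the weights are exactly what they were.
weights_unchanged = torch.equal(weights_before, model.weight)

print(flag_after_train, flag_after_train_false, flag_after_eval)
print(weights_unchanged)
```

Actual training still requires a forward pass, a loss, a backward pass and an optimizer step; none of that is triggered by these calls.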




ANSWER 2

Score 100


Here is the code for nn.Module.train():

def train(self, mode=True):
    r"""Sets the module in training mode."""
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

Here is the code for nn.Module.eval():

def eval(self):
    r"""Sets the module in evaluation mode."""
    return self.train(False)

By default, the self.training flag is set to True, i.e., modules are in train mode by default. When self.training is False, the module is in the opposite state, eval mode.

Of the most commonly used layers, only Dropout and BatchNorm care about that flag.
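The default value of the flag and its propagation to child modules can be inspected as follows. This is a sketch with an arbitrary nn.Sequential model; the layer choice is an assumption made only for the example.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
    nn.Dropout(p=0.5),
)

# modules() yields the container itself followed by its three children.
default_flags = [m.training for m in model.modules()]

model.eval()    # recursively sets training=False on every submodule
eval_flags = [m.training for m in model.modules()]

model.train()   # recursively sets training=True again
train_flags = [m.training for m in model.modules()]

print(default_flags)
print(eval_flags)
print(train_flags)
```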




ANSWER 3

Score 43


model.train()
Sets the model in training mode, i.e.:
- BatchNorm layers use per-batch statistics
- Dropout layers are activated, etc.

model.eval()
Sets the model in evaluation (inference) mode, i.e.:
- BatchNorm layers use running statistics
- Dropout layers are deactivated, etc.
Equivalent to model.train(False).

Note: neither of these function calls runs a forward or backward pass. They only tell the model how to act when it is run.

This is important because some modules (layers), e.g. Dropout and BatchNorm, are designed to behave differently during training and inference, and hence the model will produce unexpected results if run in the wrong mode.
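To illustrate how the wrong mode changes the output, here is a minimal sketch using a standalone nn.Dropout layer. The probability p=0.5 and the all-ones input are assumptions chosen to make the effect easy to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)   # roughly half the entries zeroed, the rest scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)    # dropout is a no-op in evaluation mode

print(y_train[:10])
print(torch.equal(y_eval, x))
```

Running inference with the layer left in training mode would therefore inject random zeros into the activations.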




ANSWER 4

Score 17


There are two ways of letting the model know your intention, i.e. whether you want to train the model or use it for evaluation. Calling model.train() tells the model that it is being trained, while model.eval() tells it that nothing new is to be learnt and that it is being used for testing. model.eval() is also necessary because, if the model uses BatchNorm and you pass just a single image at test time without calling model.eval(), PyTorch can throw an error.
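One case where this error appears is a BatchNorm layer that receives only one value per channel, for example nn.BatchNorm1d fed a batch of size one. The sketch below assumes a hypothetical two-layer model; the torch.no_grad() context is an optional addition that is not mentioned in the answers.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
single = torch.randn(1, 4)      # a "batch" containing one sample

model.train()
try:
    model(single)               # batch statistics cannot be computed from one value
    train_error = None
except ValueError as exc:
    train_error = str(exc)

model.eval()
with torch.no_grad():           # optional: skips gradient tracking during inference
    out = model(single)         # uses running statistics, so one sample is fine

print(train_error)
print(out.shape)
```

A convolutional model with nn.BatchNorm2d and feature maps larger than 1x1 has several values per channel even for a single image, so it may not raise in training mode, but it would still normalise with that image's own statistics rather than the running ones.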