Saving PyTorch model with no access to model class code

If you plan to do inference with the PyTorch library available (i.e., PyTorch in Python, C++, or any other platform it supports), then the best way to do this is via TorchScript.

I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).

Here's a really simple example. We make two files:

train.py :

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, x)  # record the graph by running the model on an example input
torch.jit.save(traced_cell, "model.pth")

infer.py :

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

Running these two scripts in sequence gives:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

The outputs match, so we are good. (Note that the actual numbers will differ from run to run, because the nn.Linear layer is randomly initialised.)

TorchScript allows much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. See the TorchScript docs for more advanced possibilities.
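For instance, torch.jit.script compiles the module's source rather than tracing it, so data-dependent branches survive serialization. A minimal sketch (the Gate module and file name here are illustrative, not part of the example above):

import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: a trace would only record the path taken
        # for the example input, but scripting preserves both branches.
        if x.sum() > 0:
            return torch.relu(x)
        else:
            return -x

scripted = torch.jit.script(Gate())  # compile from source instead of tracing
torch.jit.save(scripted, "gate.pth")
# torch.jit.load("gate.pth") restores the module, branches and all, with no class code.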


I recommend converting your PyTorch model to ONNX and saving it that way. That is probably the best way to store a model without access to the class.
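As a minimal sketch of that workflow (it reuses the toy model from train.py above; the onnxruntime package and the input shape are assumptions, and both onnx tools are installed separately):

import torch
import numpy as np
import onnxruntime  # assumed: pip install onnxruntime

class Model(torch.nn.Module):  # same toy model as train.py above
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# The class is needed only at export time:
torch.onnx.export(Model(), torch.randn(1, 4), "model.onnx")

# Inference elsewhere needs just the .onnx file and a runtime, no class code:
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
print(session.run(None, {input_name: np.random.randn(1, 4).astype(np.float32)})[0])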


Supplying an official answer by one of the core PyTorch devs (smth):

There are limitations to loading a PyTorch model without code.

First limitation: We only save the source code of the class definition. We do not save anything beyond that (such as the sources of the packages the class refers to).

For example:

import foo

class MyModel(...):
    def forward(self, input):
        return foo.bar(input)

Here the package foo is not saved in the model checkpoint.

Second limitation: There are limitations on robustly serializing Python constructs. For example, the default picklers cannot serialize lambdas. There are helper packages that can serialize more Python constructs than the standard pickler, but they still have limitations; dill is one such package.
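For instance, a module-level lambda illustrates the difference (a minimal sketch; dill is a third-party package installed separately):

import pickle
import dill

double = lambda x: x * 2  # module-level lambda

# pickle.dumps(double) raises pickle.PicklingError: pickle stores functions
# by reference to their importable name, and a lambda has none.
payload = dill.dumps(double)  # dill serializes the function body itself
restored = dill.loads(payload)
print(restored(3))  # prints 6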

Given these limitations, there is no robust way to have torch.load work without having the original source files.
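To make the failure mode concrete, here is a minimal sketch in the same two-file style as the example above (the file names are hypothetical, and recent PyTorch versions may also require weights_only=False in torch.load before this unpickling is even attempted):

save_full.py :

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Pickles a reference to __main__.Model, not the sources the class depends on
torch.save(Model(), "full_model.pth")

load_elsewhere.py :

import torch

# No Model class is defined or importable here, so unpickling fails with
# something like: AttributeError: Can't get attribute 'Model' on <module '__main__'>
torch.load("full_model.pth")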