Is there any way I can download the pre-trained models available in PyTorch to a specific path?

As @dennlinger mentioned in his answer, torch.utils.model_zoo is called internally when you load a pre-trained model.

More specifically, torch.utils.model_zoo.load_url() is called every time a pre-trained model is loaded. Its documentation states:

The default value of model_dir is $TORCH_HOME/models where $TORCH_HOME defaults to ~/.torch.

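The default directory can be overridden with the $TORCH_HOME environment variable.

(That is the wording for older releases. In recent releases, load_url is an alias of torch.hub.load_state_dict_from_url, the default model_dir is $TORCH_HOME/hub/checkpoints, and $TORCH_HOME defaults to ~/.cache/torch, or to $XDG_CACHE_HOME/torch when XDG_CACHE_HOME is set. Overriding $TORCH_HOME still works the same way.)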

This can be done as follows:

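import os

import torchvision

# Suppose you want the pre-trained ResNet weights stored under models/resnet
# (os.path.join builds the path with the right separator on Windows and Unix)
os.environ['TORCH_HOME'] = os.path.join('models', 'resnet')  # must be set before the weights are downloaded

# pretrained=True is deprecated in torchvision >= 0.13; there, use weights=torchvision.models.ResNet18_Weights.DEFAULT
resnet = torchvision.models.resnet18(pretrained=True)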
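Note that TORCH_HOME is a root directory, not the final folder. The hypothetical helper below mirrors the lookup order that recent PyTorch releases appear to use (an assumption based on torch.hub's documented behaviour, not PyTorch's own code; the older rule quoted above used $TORCH_HOME/models and ~/.torch). With the setting above, it predicts the weights would land in models/resnet/hub/checkpoints:

```python
import os


def expected_checkpoint_dir(environ=os.environ):
    """Predict the folder where recent PyTorch releases cache downloaded weights.

    Hypothetical helper mirroring torch.hub's lookup order (assumed, not PyTorch code):
      1. $TORCH_HOME/hub/checkpoints            if TORCH_HOME is set
      2. $XDG_CACHE_HOME/torch/hub/checkpoints  if XDG_CACHE_HOME is set
      3. ~/.cache/torch/hub/checkpoints         otherwise
    A directory chosen with torch.hub.set_dir() takes precedence; it is not modelled here.
    """
    cache_root = environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = environ.get("TORCH_HOME", os.path.join(cache_root, "torch"))
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")


# With TORCH_HOME set as in the snippet above
print(expected_checkpoint_dir({"TORCH_HOME": os.path.join("models", "resnet")}))
```

If the file is not where you expect it, comparing against this path is a quick way to see whether your PyTorch version follows the older or the newer layout.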

I found this solution by raising an issue in the torchvision GitHub repository: https://github.com/pytorch/vision/issues/616

The issue led to the documentation being updated to describe the solution above.
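If you would rather not change an environment variable for the whole process, the download directory can also be passed per call. A minimal sketch, assuming a PyTorch release that provides torch.hub.load_state_dict_from_url (recent versions make torch.utils.model_zoo.load_url an alias of it) and reusing the ResNet-101 URL from the answer below; the function name is hypothetical. torch.hub.set_dir() is another option that changes the cache root for the rest of the session.

```python
import os


def load_resnet101_into(target_dir):
    """Download ResNet-101 ImageNet weights into target_dir and return the model.

    Hypothetical helper: the weights file is cached in target_dir and reused
    on later calls instead of being downloaded again.
    """
    # Imported here so the heavy dependencies load only when the helper runs
    import torch.hub
    import torchvision

    os.makedirs(target_dir, exist_ok=True)
    url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
    # model_dir replaces the default $TORCH_HOME-based cache location
    state_dict = torch.hub.load_state_dict_from_url(url, model_dir=target_dir, progress=True)

    model = torchvision.models.resnet101()  # architecture only, randomly initialised
    model.load_state_dict(state_dict)
    return model
```

Calling load_resnet101_into(os.path.join("models", "resnet")) would then place the .pth file directly in that folder, without the hub/checkpoints subdirectories.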


Yes, you can simply copy the URLs and use wget to download the files to the desired path (run wget from that directory, or pass -P <directory>). Here's an illustration:

For AlexNet:

$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth

For Google Inception (v3):

$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth

For SqueezeNet:

$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth

For MobileNetV2:

$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth

For DenseNet201:

$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth

For MNASNet1_0:

$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth

For ShuffleNetv2_x1.0:

$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

If you want to do it in Python, use something like:

import urllib.request

# ResNet-101 host URL
url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"

# download the file and save it as `resnet_101.pth` in the current directory;
# pass a path inside an existing directory (e.g. "models/resnet_101.pth") to save it elsewhere
urllib.request.urlretrieve(url, "resnet_101.pth")
# returns a (filename, headers) tuple: ('resnet_101.pth', <http.client.HTTPMessage ...>)
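To put several of the files listed above into one folder from Python, here is a minimal standard-library sketch; the helper names are hypothetical. It keeps each original file name because torchvision weight files follow the name-<hash>.pth convention, where the hex digits after the last hyphen are the start of the file's SHA-256 digest (the same convention torch.hub checks when check_hash=True). Renaming the file, as with resnet_101.pth above, loses that check.

```python
import hashlib
import os
import posixpath
import re
import urllib.parse
import urllib.request


def destination_path(url, target_dir):
    """Return the path inside target_dir that keeps the file name from the URL."""
    filename = posixpath.basename(urllib.parse.urlparse(url).path)
    return os.path.join(target_dir, filename)


def hash_prefix_matches(path):
    """Compare the file's SHA-256 digest with the hex prefix in a `name-<hash>.ext` file name.

    Returns True/False, or None when the name carries no hash prefix (e.g. a renamed file).
    """
    match = re.search(r"-([0-9a-f]+)\.[^.]+$", os.path.basename(path))
    if match is None:
        return None
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks so large weight files are not read into memory at once
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(match.group(1))


def download_all(urls, target_dir):
    """Download each URL into target_dir, skipping files that are already present.

    Unlike `wget -c`, an interrupted download is not resumed; a truncated file
    whose name carries a hash prefix is caught by the check below instead.
    """
    os.makedirs(target_dir, exist_ok=True)
    for url in urls:
        path = destination_path(url, target_dir)
        if not os.path.exists(path):
            urllib.request.urlretrieve(url, path)
        if hash_prefix_matches(path) is False:
            raise ValueError(f"SHA-256 mismatch for {path}; delete it and retry")
```

For example, download_all(["https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth", "https://download.pytorch.org/models/densenet201-c1103571.pth"], "pretrained") should leave both files under pretrained/ with their original names.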

P.S.: You can find the download URLs in the respective Python modules of torchvision.models (for example, torchvision/models/resnet.py).
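In newer torchvision releases (0.13 and later, as far as I know), the URLs are attached to the weights enums such as torchvision.models.ResNet18_Weights, and each enum member exposes a url attribute. A small sketch that collects them; the helper name is hypothetical:

```python
def weight_urls(weights_enum):
    """Return {member name: download URL} for a torchvision weights enum.

    Assumes torchvision >= 0.13, where members such as
    torchvision.models.ResNet18_Weights.IMAGENET1K_V1 expose a `.url` attribute.
    """
    return {member.name: member.url for member in weights_enum}
```

For example, weight_urls(torchvision.models.ResNet18_Weights) maps IMAGENET1K_V1 to its .pth URL, which you can then pass to wget or urlretrieve as above.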