Friday, 23 February 2024
Save and Load Your PyTorch Models
A deep learning model is a mathematical abstraction of your data, involving a lot of parameters. Training these parameters can take hours, days, or even weeks, but afterward you can apply the result to new data. This is called inference in machine learning. It is important to know how to preserve a trained model on disk and later load it for inference. In this post, you will discover how to save your PyTorch models to files and load them up again to make predictions. After reading this post, you will know:
What are states and parameters in a PyTorch model
How to save model states
How to load model states
Overview
This post is in three parts; they are:
Build an Example Model
What’s Inside a PyTorch Model
Accessing state_dict of a Model
Build an Example Model
Let’s start with a very simple model in PyTorch, built on the iris dataset. You will load the dataset using scikit-learn (in which the targets are the integer labels 0, 1, and 2) and train a neural network for this multiclass classification problem. This model uses log softmax as the output activation so it can be combined with the negative log likelihood loss function; that is equivalent to no output activation combined with the cross entropy loss function.
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
With such a simple model and a small dataset, training shouldn’t take long. Afterwards, we can confirm that the model works by evaluating it with the test set:
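The full listing is truncated to its imports in this copy of the post. A minimal, self-contained sketch of the model, training loop, and test-set evaluation it describes might look like the following (the layer sizes, epoch count, and hyperparameters here are illustrative assumptions, not necessarily those of the original listing):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

torch.manual_seed(42)

# load iris: 4 features per sample, integer targets 0, 1, and 2
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        return self.logsoftmax(self.output(x))

model = Multiclass()
loss_fn = nn.NLLLoss()  # pairs with the log softmax output activation
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    y_pred = model(X_train)
    loss = loss_fn(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# evaluate on the held-out test set
with torch.no_grad():
    y_pred = model(X_test)
    acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy:", float(acc))
```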
What’s Inside a PyTorch Model
A PyTorch model is a Python object. It holds deep learning building blocks such as various kinds of layers and activation functions, and it knows how to connect them to produce an output from your input tensors. The algorithm of a model is fixed at the time you create it; however, the model has trainable parameters that are modified during the training loop so it can become more accurate.
You saw how to get the model parameters when you set up the optimizer for your training loop, namely,
optimizer = optim.Adam(model.parameters(), lr=0.001)
The function model.parameters() gives you a generator that references each layer’s trainable parameters in turn, in the form of PyTorch tensors. Therefore you can make a copy of them or overwrite them, for example:
# create a new model
newmodel = Multiclass()
# ask PyTorch to ignore autograd on update and overwrite parameters
with torch.no_grad():
    for newtensor, tensor in zip(newmodel.parameters(), model.parameters()):
        newtensor.copy_(tensor)
The result should be exactly the same as before, since you essentially made the two models identical by copying the parameters.
However, this is not always the case. Some models have non-trainable parameters. One example is the batch normalization layer, common in many convolutional neural networks. What it does is apply normalization on the tensors produced by the previous layer and pass the normalized tensor on to the next layer. It keeps a running mean and standard deviation, which are learned from your input data during the training loop but are not updated by the optimizer. Therefore, they are not part of model.parameters(), but they are equally important.
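You can see this distinction directly in PyTorch: a batch normalization layer exposes its trainable affine parameters through named_parameters(), while the running statistics are stored as buffers:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(4)
# trainable parameters: the affine scale (weight) and shift (bias)
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
# non-trainable state: the running statistics, kept as buffers
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```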
Accessing state_dict of a Model
To access all parameters of a model, trainable or not, you can get them from the state_dict() method. From the model above, this is what you can get:
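The output block is missing from this copy of the post. As an illustration, here is a small sketch of calling state_dict() on a model with two linear layers (the layer names hidden and output are assumptions for illustration); each layer contributes its weight and bias under a dotted name:

```python
import torch.nn as nn

class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.output = nn.Linear(8, 3)

model = Multiclass()
print(list(model.state_dict().keys()))
# ['hidden.weight', 'hidden.bias', 'output.weight', 'output.bias']
```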
It is called state_dict because all state variables of a model are here. It is an OrderedDict object from Python’s built-in collections module. Every component of a PyTorch model has a name, and so do its parameters. The OrderedDict object allows you to map the weights back to the parameters correctly by matching their names.
This is how you should save and load the model: fetch the model states into an OrderedDict, then serialize and save it to disk. For inference, you create the model first (without training) and load the states into it. In Python, the native format for serialization is pickle:
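The code block is truncated here. A minimal sketch of that save-then-load round trip might look like the following (torch.save() uses pickle under the hood; the filename and the stand-in model are illustrative assumptions):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 3)  # stand-in for the trained model

# save: serialize the state dict to disk
torch.save(model.state_dict(), "iris-model.pth")

# later, for inference: recreate the model, then load the states into it
newmodel = nn.Linear(4, 3)
newmodel.load_state_dict(torch.load("iris-model.pth"))
newmodel.eval()  # switch to inference mode before predicting
```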
The *.pth file is in fact a zip archive of pickle files created by PyTorch. This format is recommended because PyTorch can store additional information in it. Note that you stored only the states, not the model itself. You still need to create the model with Python code and load the states into it. If you wish to store the model as well, you can pass in the entire model instead of the states:
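A sketch of the whole-model variant (the filename and model are illustrative; note that recent PyTorch versions default torch.load() to weights_only=True, so loading a full pickled model requires weights_only=False, which you should only use on files from a source you trust):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# pickles the entire model object, not just its state dict
torch.save(model, "iris-model-full.pth")

# loads the object back; no need to rebuild the architecture first,
# but the class definitions must still be importable
model2 = torch.load("iris-model-full.pth", weights_only=False)
```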
But remember: due to the nature of the Python language, doing so does not relieve you from keeping the code of the model. The newmodel object above is an instance of the Multiclass class that you defined before. When you load the model from disk, Python needs to know in detail how this class is defined. If you run a script with just the line torch.load(), you will see an error message like the following: