
I am looking for the best way to save a trained model in PyTorch. These are the options I am using right now:

  1. Save the model with model.state_dict() and load it with model.load_state_dict().
  2. Save the model with torch.save() and load it with torch.load().

I read somewhere that approach 2 is better than approach 1.

Why is the second approach preferred? Is the reason that torch.nn modules have the two functions above?

1 Answer


The pickle library implements serialization and de-serialization of Python objects.

The first way is to import torch (which imports pickle) and call torch.save() or torch.load(), which wrap pickle.dump() and pickle.load() for you. pickle.dump() and pickle.load() are the actual methods used to save and load an object.
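A minimal sketch of saving and loading an entire model this way. The file name whole_model.pt is an arbitrary choice, and the weights_only=False argument is an assumption needed on newer PyTorch versions, where torch.load no longer un-pickles arbitrary objects by default:

```python
import torch

# Build a small model and pickle the entire object to disk.
model = torch.nn.Linear(5, 2)
torch.save(model, "whole_model.pt")

# torch.load un-pickles the object; the class definition must be
# importable at load time, which is one drawback of this approach.
restored = torch.load("whole_model.pt", weights_only=False)
```

Because the whole object is pickled, the saved file is tied to the exact class and directory structure used when saving, which is why this approach can break after refactoring.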

Syntax of torch.save():

torch.save(the_model.state_dict(), PATH)  
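A small round-trip sketch of this state_dict approach, assuming an arbitrary file name model_state.pt; note that loading requires re-creating the model first:

```python
import torch

model = torch.nn.Linear(5, 2)

# Save only the parameters: an ordered dict mapping names to tensors.
torch.save(model.state_dict(), "model_state.pt")

# To load, first re-create a model with the same architecture,
# then restore its parameters from the saved dict.
fresh = torch.nn.Linear(5, 2)
fresh.load_state_dict(torch.load("model_state.pt"))
```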

Second way,

A torch.nn.Module has learnable parameters, which make up the first state_dict; the second state_dict is the optimizer's, holding the state the optimizer uses to update those learnable parameters. Since state_dict objects are ordinary Python dictionaries, you can easily save, update, alter, and restore them, which is why this approach is preferred over saving the entire model with torch.save().

import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
opti = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model's weight:")
print(model.weight)
print("Model's bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for variable_name in opti.state_dict():
    print(variable_name, "\t", opti.state_dict()[variable_name])
...
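The two state_dicts above can be bundled into a single checkpoint and restored later. This is a sketch under assumptions: the file name checkpoint.pt and the "model"/"optimizer"/"epoch" keys are arbitrary choices, not a fixed PyTorch convention:

```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)
opti = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Take one training step so the optimizer accumulates momentum state.
loss = model(torch.randn(3, 5)).sum()
loss.backward()
opti.step()

# Bundle both state_dicts (plus any bookkeeping) into one checkpoint.
torch.save({"model": model.state_dict(),
            "optimizer": opti.state_dict(),
            "epoch": 1}, "checkpoint.pt")

# Restore: rebuild the objects, then load each state_dict back.
model2 = torch.nn.Linear(5, 2)
opti2 = optim.SGD(model2.parameters(), lr=0.001, momentum=0.9)
ckpt = torch.load("checkpoint.pt")
model2.load_state_dict(ckpt["model"])
opti2.load_state_dict(ckpt["optimizer"])
```

Restoring the optimizer's state_dict alongside the model's is what lets training resume exactly where it left off, momentum buffers included.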