The pickle library implements serializing and de-serializing of Python objects.
The first way is to import torch, which imports pickle, and then call torch.save() or torch.load(); these wrap pickle.dump() and pickle.load() for you. pickle.dump() and pickle.load() are the actual functions used to save and load an object.
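To make the pickle step concrete, here is a minimal sketch of a round trip with pickle.dump() and pickle.load() on a plain Python object. The settings dictionary and the file name settings.pkl are made up for illustration and are not part of the original example.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical object to serialize; any picklable Python object would do.
settings = {"lr": 0.001, "momentum": 0.9, "layers": [5, 2]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "settings.pkl"  # assumed file name

    # pickle.dump() writes the object to a file opened in binary mode.
    with open(path, "wb") as f:
        pickle.dump(settings, f)

    # pickle.load() reads the bytes back and rebuilds an equal object.
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == settings)
```

The restored object is a new dictionary with the same contents, not the original object itself.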
Syntax of torch.save():
torch.save(the_model.state_dict(), PATH)
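The call above can be tried end to end with a small sketch like the following. It assumes a tiny torch.nn.Linear(5, 2) model standing in for the_model and a temporary file for PATH; both are illustrative choices rather than anything the original prescribes.

```python
import os
import tempfile

import torch

# Assumed stand-in for the_model: a small linear layer.
the_model = torch.nn.Linear(5, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    PATH = os.path.join(tmp_dir, "model_state.pt")  # assumed file name

    # Save only the parameters (the state_dict), not the whole model object.
    torch.save(the_model.state_dict(), PATH)

    # To restore, build a model with the same architecture first,
    # then load the saved parameters into it.
    restored_model = torch.nn.Linear(5, 2)
    restored_model.load_state_dict(torch.load(PATH))

same_weight = torch.equal(the_model.weight, restored_model.weight)
same_bias = torch.equal(the_model.bias, restored_model.bias)
print(same_weight, same_bias)
```

Note that loading needs a model of the same shape to exist already, because the file holds only tensors keyed by parameter name.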
The second way is to work with state_dict objects directly.
A torch.nn.Module has learnable parameters (weights and biases), and its state_dict maps each parameter name to its tensor. The optimizer, which updates those learnable parameters, has a second state_dict holding its own state and hyperparameters. Since state_dict objects are Python dictionaries, you can easily save, update, alter and restore them, which is why saving a state_dict is preferred over passing the entire model object to torch.save().
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
opti = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("weight of model:")
print(model.weight)
print("bias:")
print(model.bias)
print("---")
print("Optimizer's state dict:")
for variable_name in opti.state_dict():
    print(variable_name, "\t", opti.state_dict()[variable_name])
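Because both state dicts are ordinary dictionaries, one possible way to save them together is to put them in a single checkpoint dictionary. The sketch below assumes the same Linear(5, 2) model and SGD optimizer as above; the keys model_state_dict and optimizer_state_dict, the file name, and the single dummy training step are illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)
opti = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(4, 5)).sum()
loss.backward()
opti.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "checkpoint.pt")  # assumed name

    # Bundle both state dicts in one dictionary (key names are arbitrary).
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opti.state_dict(),
        },
        checkpoint_path,
    )
    checkpoint = torch.load(checkpoint_path)

# Rebuild the objects, then restore their state from the checkpoint.
new_model = torch.nn.Linear(5, 2)
new_opti = optim.SGD(new_model.parameters(), lr=0.5)  # lr gets overwritten
new_model.load_state_dict(checkpoint["model_state_dict"])
new_opti.load_state_dict(checkpoint["optimizer_state_dict"])

print(new_opti.param_groups[0]["lr"], new_opti.param_groups[0]["momentum"])
```

Loading the optimizer state dict restores its hyperparameters as well, so the learning rate passed when constructing new_opti is replaced by the saved one.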