# Sharing Model Data between PyTorch and TorchSharp
There are typically two kinds of state in a model – parameters, which contain trained weights, and buffers, which contain data that is not trained, but still essential for the functioning of the model. Both should generally be saved and loaded when serializing models. There is also optimizer state, which isn’t part of the model itself, but rather training data. It can be useful to save optimizer state to be able to resume a training loop if it crashes or runs out of time.
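To make the distinction concrete, here is a minimal PyTorch sketch (the `Normalize` class and its fields are purely illustrative) showing that both parameters and buffers end up in a module's state dictionary:

```python
import torch
from torch import nn

class Normalize(nn.Module):
    def __init__(self):
        super().__init__()
        # A parameter: updated by the optimizer during training.
        self.scale = nn.Parameter(torch.ones(3))
        # A buffer: not trained, but still part of the persisted state.
        self.register_buffer("running_mean", torch.zeros(3))

m = Normalize()
# Both the parameter and the buffer appear in the state dict:
print(sorted(m.state_dict().keys()))  # ['running_mean', 'scale']
```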
When using PyTorch, the expected pattern to use when saving and later restoring models from disk or other permanent storage media, is to get the model’s state and pickle that using the standard Python format, which is what torch.save() does.
```python
torch.save(model.state_dict(), 'model_weights.pth')
```
When restoring the model, you are expected to first create a model with exactly the same structure as the original, with random weights, and then restore its state from the unpickled object:
```python
model = [...]
model.load_state_dict(torch.load('model_weights.pth'))
```
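Put together, the save and load snippets form a round trip. A minimal runnable sketch, using a plain `nn.Linear` as a stand-in for a real model:

```python
import os
import tempfile
import torch
from torch import nn

model = nn.Linear(10, 2)

# Save only the state, not the model object itself.
path = os.path.join(tempfile.mkdtemp(), 'model_weights.pth')
torch.save(model.state_dict(), path)

# Recreate a model with the exact same structure, then restore the state.
model2 = nn.Linear(10, 2)
model2.load_state_dict(torch.load(path))

assert torch.equal(model.weight, model2.weight)
```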
This presents a couple of problems for a .NET implementation.
Python pickling is intimately coupled to Python and its runtime object model. It is a complex format that supports object graphs forming DAGs, faithfully maintaining all object state in the way necessary to restore the Python object later.
We are unaware of a .NET library that faithfully restores pickled Python classes in .NET; those unpickling libraries that are known to the maintainers of TorchSharp have limitations that make them unsuitable for our needs.
Therefore, TorchSharp, in its current form, implements its own very simple model serialization format, which allows models originating in either .NET or Python to be loaded using .NET, as long as the model was saved using the special format.
The MNIST and AdversarialExampleGeneration examples in this repo rely on saving and restoring model state – the latter example relies on a pre-trained model from MNIST.
In C#, saving a model looks like this:
```csharp
model.save("model_weights.dat");
```
And loading it again is done by:
```csharp
model = [...];
model.load("model_weights.dat");
```
For efficient memory management, the model should be created on the CPU before loading weights, then moved to the target device.
It is __critical__ that all submodules and buffers in a custom module or composed by a `Sequential` object have exactly the same name in the original and target models, since that is how persisted tensors are associated with the model into which they are loaded.
The CustomModule method `RegisterComponents` will automatically find all fields that are either modules or tensors, registering the former as submodules and the latter as buffers. It registers all of these using the name of the field, just like the PyTorch `Module` base class does.
If the model starts out in Python, there’s a simple script that allows you to use code that is very similar to the PyTorch API to save models to the TorchSharp format. Rather than placing this trivial script in a Python package and publishing it, we chose to just refer you to the script file itself, [exportsd.py](../blob/main/src/Python/exportsd.py), which has all the necessary code.
```python
f = open("model_weights.dat", "wb")
exportsd.save_state_dict(model.to("cpu").state_dict(), f)
f.close()
```
If the model starts out in TorchSharp, there’s also a simple script that allows you to load TorchSharp models in Python. All the necessary code can be found in [importsd.py](../blob/main/src/Python/importsd.py). Here is an example of using the script:
```python
f = open("model_weights.dat", "rb")
model.load_state_dict(importsd.load_state_dict(f))
f.close()
```
You can also check [TestSaveSD.cs](../blob/main/test/TorchSharpTest/TestSaveSD.cs) and [pyimporttest.py](../blob/main/test/TorchSharpTest/pyimporttest.py) for more examples.
For those seeking additional flexibility, especially in a mixed .NET and Python environment, TorchSharp.PyBridge offers an alternative approach. Developed by Shaltiel Shmidman, this extension library facilitates seamless interoperability between .NET and Python for model serialization, simplifying the process of saving and loading PyTorch models in a .NET environment.
Key features include:
- `load_py` Method: Easily load PyTorch models saved in the standard Python format directly into TorchSharp.
- `save_py` Method: Save TorchSharp models in a format that can be directly loaded in PyTorch, offering cross-platform model compatibility.
Please note that TorchSharp.PyBridge is not maintained by the TorchSharp team and is an independent extension package. For detailed usage instructions, limitations, and more information, visit [TorchSharp.PyBridge on GitHub](https://github.com/shaltielshmid/TorchSharp.PyBridge).
Starting with release 0.96.9, you can load TorchScript modules and functions that have been either traced or scripted in PyTorch. It is, however, not yet possible to create a TorchScript module from scratch using TorchSharp.
The use of TorchScript is described in a separate article: [TorchScript](./TorchScript)
Starting with version 0.99.4, it is possible to export optimizer state in PyTorch and load it in TorchSharp (but not the other way around). In order to make this work, you need to take great care to make sure that the optimizer loading the state in TorchSharp is created the exact same way as the original optimizer in PyTorch.
This means that the same parameters are passed to the constructor, in the same order, that all parameter groups are created in the same order, and with the exact same parameters. Anything else will result in a runtime error (at best) or corrupt data (at worst).
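The reason the order matters is that the saved state is keyed by parameter *position* within each parameter group, not by name. A small runnable PyTorch sketch (the `nn.Linear` layers and hyperparameter values are just for illustration) showing the group ordering that the TorchSharp side must reproduce exactly:

```python
import torch
from torch import nn

lin1, lin2 = nn.Linear(4, 4), nn.Linear(4, 2)

optim = torch.optim.Adam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9))
optim.add_param_group({'params': lin2.parameters(), 'lr': 0.01})

# The parameter groups appear in the state dict in creation order, each
# carrying its own hyperparameters; a mismatched order on the loading
# side would silently associate state with the wrong parameters.
groups = optim.state_dict()['param_groups']
print([g['lr'] for g in groups])  # [0.001, 0.01]
```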
The `exportsd.py` file used to export model data, i.e. weights and buffers, is also used to export optimizer state. Each supported optimizer has a `save_xxx(optim, stream)` function that takes a specific optimizer kind and saves its state.
These functions are:
```python
save_sgd(optim, stream)
save_adadelta(optim, stream)
save_adagrad(optim, stream)
save_adam(optim, stream)
save_adamax(optim, stream)
save_adamw(optim, stream)
save_asgd(optim, stream)
save_nadam(optim, stream)
save_radam(optim, stream)
save_rprop(optim, stream)
save_rmsprop(optim, stream)
```
An example of how to use these:
```python
optim = torch.optim.Adam(lin1.parameters(), lr=0.001, betas=(0.8, 0.9))
optim.add_param_group({'params': lin2.parameters(), 'lr': 0.01, 'betas': (0.7, 0.79), 'amsgrad': True})

[...]

f = open("adam_state.dat", "wb")
exportsd.save_adam(optim, f)
f.close()
```
The state can then be loaded in C#, into a TorchSharp optimizer created with matching parameter groups:

```csharp
var pgs = new Adam.ParamGroup[] { new (lin1.parameters()), new (lin2.parameters()) };
var optimizer = torch.optim.Adam(pgs, 0.00004f);
optimizer.load_state_dict("adam_state.dat");
```
Note that there is no support for saving or loading state for the LBFGS optimizer, which is implemented entirely in native code. If a managed-code implementation of LBFGS is done in the future, we will add support for state sharing along with it.