[PROPOSAL] Lazy initialization of model #3124
ver217 started this conversation in Development | Core
What are LazyTensor and LazyInit
LazyTensor allows a DL framework (PyTorch) to execute operations lazily: it records all operations applied to it and reruns them when the tensor needs to be materialized.
LazyInit defers model initialization and is built on top of LazyTensor.
Why we need to implement a new LazyInit
ColossalAI already has a similar feature: ColoInitContext. It hijacks the __init__() method of each module and shards each layer with a preset sharding strategy. This works well when the sharding strategy is known before model initialization. However, if the sharding strategy is generated from static analysis of the model, this approach no longer works.
So we need to initialize model tensors as meta tensors, run static analysis on the model to derive the sharding strategy, and then materialize each tensor and apply the strategy. The static analysis can be omitted if the sharding strategy is known in advance.
A possible initialization process:
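For illustration, here is a minimal sketch of that flow using plain PyTorch meta tensors. It only mimics the idea: the proposed LazyTensor would replay the recorded initialization OPs instead of re-initializing by hand as done here.

```python
import torch
import torch.nn as nn

# Step 1: build the model on the meta device. No memory is allocated and
# no real initialization runs; parameters only carry shape/dtype metadata.
with torch.device("meta"):
    model = nn.Linear(1024, 1024)

# Step 2: static analysis of the meta model would run here to derive a
# sharding strategy (omitted when the strategy is already known).

# Step 3: materialize each tensor and apply the sharding strategy.
# to_empty() allocates real (uninitialized) storage; the proposed LazyTensor
# would instead rerun the recorded initialization OPs to restore the values.
model = model.to_empty(device="cpu")
for param in model.parameters():
    nn.init.normal_(param)  # stand-in for replaying the original init OPs
```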
Actually, torchdistx has already implemented similar features, named FakeTensor and deferred_init. Why do we need to implement a new one? We want lazy initialization to work together with DTensor, and torchdistx is a kind of black box to us.
Method
We already have experimental code for lazy init. Thanks to @super-dainiu.
We can start this work from this file. We implement a LazyTensor class which records all OPs and reruns them when materializing.
A possible class definition of LazyTensor:
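As a rough illustration of what such a class could look like (this is not the actual implementation in the experimental code), here is a minimal LazyTensor that records its creation OP plus every subsequent method call and replays them on materialization:

```python
import torch

class LazyTensor:
    """Minimal sketch: record the creation OP and every subsequent method
    call, and rerun them when the tensor is materialized."""

    def __init__(self, func, *args, **kwargs):
        self._ops = [(func, args, kwargs)]  # recorded creation OP
        self._concrete = None               # cached materialized tensor

    def __getattr__(self, name):
        # Record tensor method calls (e.g. .normal_(), .mul_()) instead of
        # executing them eagerly.
        def record(*args, **kwargs):
            self._ops.append((name, args, kwargs))
            return self
        return record

    def materialize(self) -> torch.Tensor:
        """Replay all recorded OPs to produce a concrete tensor."""
        if self._concrete is None:
            func, args, kwargs = self._ops[0]
            tensor = func(*args, **kwargs)
            for name, args, kwargs in self._ops[1:]:
                tensor = getattr(tensor, name)(*args, **kwargs)
            self._concrete = tensor
        return self._concrete


# Creation and initialization are deferred until materialize() is called.
weight = LazyTensor(torch.empty, 4, 4)
weight.uniform_(-0.1, 0.1)          # recorded, not executed
print(weight.materialize().shape)   # OPs are replayed here -> torch.Size([4, 4])
```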
A possible usage of LazyInit:
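Roughly, usage could look like the following. LazyInitContext, materialize() and distribute() are the names used in this proposal, but the exact signatures shown here are assumptions, and the snippet will not run until the class is implemented:

```python
import torch.nn as nn

# Hypothetical usage; LazyInitContext does not exist yet and the argument
# list of distribute() is an assumption.
with LazyInitContext():
    # Parameters are created as LazyTensors; no real memory is allocated
    # and no initialization OP is executed yet.
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))

# Materialize all parameters on the current device (single-device case) ...
LazyInitContext.materialize(model)

# ... or materialize and shard them according to a sharding strategy.
# LazyInitContext.distribute(model, sharding_strategy)
```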
LazyInitContext.materialize() and LazyInitContext.distribute() are static methods and may be replaced with downstream model wrappers.
Limitations
To keep the implementation simple, we make some trade-offs in the design.
We cannot ensure that a lazily initialized model is identical to a normally initialized one, but we can ensure that its parameters are initialized from the same distribution.
There are some cases where we cannot ensure correctness. In addition, as we don't track tensors' slice relationships, there are also cases where lazy execution won't work and tensors may be materialized early.
How to verify
To verify correctness, we have to control the random seed. We can implement a utility tensor class:
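A possible sketch of such a class (the name SeededTensor and the fixed seed are illustrative): a torch.Tensor subclass that resets the RNG via __torch_function__ before every OP.

```python
import torch

class SeededTensor(torch.Tensor):
    """Verification helper (sketch): reset the RNG to a fixed seed before
    every OP so that random OPs give reproducible results."""

    _seed = 42

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        torch.manual_seed(cls._seed)  # same random state before each OP
        return super().__torch_function__(func, types, args, kwargs or {})


a = torch.empty(4, 4).as_subclass(SeededTensor)
b = torch.empty(4, 4).as_subclass(SeededTensor)
a.normal_()
b.normal_()
assert torch.equal(a, b)  # identical because the seed is reset before each OP
```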
Thus, the random state is the same before each OP.
For LazyTensor, we reserve a hook to control the random seed before executing each OP. By doing this, we can simply compare the state dict or the forward result to verify the correctness of initialization.
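For instance, the end-to-end check could look like the snippet below. Since LazyInit is not implemented yet, a second eagerly built model stands in for the lazily initialized one; the real test would compare a lazily initialized model against an eagerly initialized one.

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# With the random state controlled before initialization, the two models
# should end up with identical parameters.
torch.manual_seed(0)
eager = build_model()
torch.manual_seed(0)
lazy_like = build_model()  # placeholder for the lazily initialized model

# Compare state dicts key by key.
for (k1, v1), (k2, v2) in zip(eager.state_dict().items(),
                              lazy_like.state_dict().items()):
    assert k1 == k2 and torch.equal(v1, v2)

# The forward results should also match.
x = torch.randn(2, 16)
assert torch.equal(eager(x), lazy_like(x))
```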
Possible Roadmap