In reversible mode, one can further save memory by computing the backward pass in chunks (see the sketch after the two items below):
a few tokens at a time for feedforward layers, since grad(concat(mlp(x1), mlp(x2))) = concat(grad(mlp(x1)), grad(mlp(x2)))
a few queries at a time for self-attention, since grad(head1 + head2) = grad(head1) + grad(head2), where head1 and head2 are attention outputs after linear projection
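The feedforward case can be sketched as follows: since the MLP acts on each token independently, recomputing its forward and backward a slice of tokens at a time gives the same input and parameter gradients as the full pass. This is only an illustration; the helper name `mlp_grad_in_chunks`, the chunk size, and the flattened (tokens, hidden) layout are assumptions, not the repo's actual API.

```python
import torch

def mlp_grad_in_chunks(mlp: torch.nn.Module, x: torch.Tensor,
                       grad_output: torch.Tensor, chunk_tokens: int = 1024):
    """Recompute the MLP forward/backward a slice of tokens at a time.

    Because the MLP is applied to each token independently, the chunked result
    equals the full-batch one: input grads are concatenated, and parameter
    grads accumulate into mlp.parameters() across chunks.
    """
    x_flat = x.reshape(-1, x.shape[-1])                       # (num_tokens, hidden)
    go_flat = grad_output.reshape(-1, grad_output.shape[-1])
    grad_input = torch.empty_like(x_flat)
    for start in range(0, x_flat.shape[0], chunk_tokens):
        end = start + chunk_tokens
        chunk = x_flat[start:end].detach().requires_grad_(True)
        out = mlp(chunk)                                       # rematerialize only this slice
        out.backward(go_flat[start:end])                       # fills chunk.grad, accumulates param grads
        grad_input[start:end] = chunk.grad
    return grad_input.reshape(x.shape)
```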
improved checkpointing (a short example follows these items):
allow the user to specify the number of checkpoints, as in checkpoint_sequential
do not rematerialize the last layer, as in checkpoint_sequential
optionally cast checkpoints to a lower precision, as in revlib
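For reference, the first two items roughly match what torch.utils.checkpoint.checkpoint_sequential already does, and the third can be imitated with saved-tensor hooks. The snippet below is a sketch under those assumptions; the stand-in model, sizes, and the bfloat16 choice are illustrative only.

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Stand-in model; layer count and sizes are illustrative only.
model = torch.nn.Sequential(*[torch.nn.Linear(512, 512) for _ in range(24)])
x = torch.randn(8, 512, requires_grad=True)

# 1) User-specified number of checkpoints: checkpoint_sequential splits the model
#    into `segments` chunks and runs the final segment without checkpointing,
#    so the last layers are never rematerialized.
segments = 4
out = checkpoint_sequential(model, segments, x)
out.sum().backward()

# 2) Lower-precision checkpoints (as revlib does) can be approximated with
#    saved-tensor hooks: store tensors saved for backward in bfloat16 and cast
#    back on use. Applied here to a plain forward pass for simplicity.
def pack(t):
    return t.to(torch.bfloat16) if t.is_floating_point() else t

def unpack(t):
    return t.float() if t.is_floating_point() else t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = model(torch.randn(8, 512, requires_grad=True))
out.sum().backward()
```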
compacted params (one possible reading is sketched below):
compacted layernorms, biases
compacted adapters
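One possible reading of "compacted" here is to keep many tiny tensors (layernorm scales, biases, adapter weights) in a few contiguous buffers, reducing allocator and optimizer overhead. The class below is purely a sketch of that idea; its name, layout, and API are assumptions, not what the repo actually implements.

```python
import torch
import torch.nn as nn

class CompactLayerNorms(nn.Module):
    """Keep the affine parameters of many LayerNorms in two contiguous buffers
    instead of 2 * num_layers tiny tensors. Sketch only; names and layout are
    assumptions."""

    def __init__(self, num_layers: int, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.hidden_size, self.eps = hidden_size, eps
        self.scales = nn.Parameter(torch.ones(num_layers, hidden_size))
        self.biases = nn.Parameter(torch.zeros(num_layers, hidden_size))

    def forward(self, x: torch.Tensor, layer_index: int) -> torch.Tensor:
        # Each layer normalizes with its own row of the shared buffers.
        return nn.functional.layer_norm(
            x, (self.hidden_size,),
            self.scales[layer_index], self.biases[layer_index], self.eps)
```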
Attention could be computed in O(sqrt(n)) memory (Rabe et al., 2021), but this may be overkill (a chunked-attention sketch follows the next item)
sparse or linear attention: great for very long sequences, but for large models attention is not the bottleneck in typical NLP and vision tasks (tested GPT-3 up to length 4096)
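A hedged sketch of the Rabe et al. (2021) idea: iterate over query and key chunks, combining partial results with a running max and softmax normalizer so the full n-by-n score matrix is never materialized. The function name, chunk sizes, and the absence of masking/dropout are illustration-only assumptions.

```python
import math
import torch

def chunked_attention(q, k, v, q_chunk=1024, k_chunk=1024):
    """Memory-efficient attention via online softmax over key/value chunks."""
    n, d = q.shape[-2], q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    outputs = []
    for qs in range(0, n, q_chunk):
        qc = q[..., qs:qs + q_chunk, :]
        acc = torch.zeros_like(qc)                                   # running weighted sum of values
        norm = torch.zeros(*qc.shape[:-1], 1, dtype=q.dtype, device=q.device)
        running_max = torch.full((*qc.shape[:-1], 1), -float("inf"),
                                 dtype=q.dtype, device=q.device)
        for ks in range(0, k.shape[-2], k_chunk):
            kc = k[..., ks:ks + k_chunk, :]
            vc = v[..., ks:ks + k_chunk, :]
            scores = qc @ kc.transpose(-2, -1) * scale               # (..., q_chunk, k_chunk)
            chunk_max = scores.amax(dim=-1, keepdim=True)
            new_max = torch.maximum(running_max, chunk_max)
            correction = torch.exp(running_max - new_max)            # rescale previous partial sums
            weights = torch.exp(scores - new_max)
            acc = acc * correction + weights @ vc
            norm = norm * correction + weights.sum(dim=-1, keepdim=True)
            running_max = new_max
        outputs.append(acc / norm)
    return torch.cat(outputs, dim=-2)
```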
Per-block grad scaling as described in Ramesh et al. (2021). We currently rely on Sandwich Norm to maintain stability up to 96 layers (we did not test more); it would be nice to have per-block scaling to avoid the need for the extra LayerNorm (a sketch follows below).
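One way to prototype per-block grad scaling in the spirit of Ramesh et al. (2021): an identity autograd function that rescales gradients at block boundaries, so each block's backward runs at its own magnitude (useful to keep fp16 gradients in range). `ScaleGrad` and the wrapper are hypothetical names; choosing and updating the per-block scale (e.g. halving on overflow) is omitted.

```python
import torch

class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by `scale`
    in the backward pass."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None


def run_block_with_grad_scaling(block, x, scale: float):
    # Backward runs output -> input: gradients entering the block are multiplied
    # by `scale`, then divided back out before reaching earlier blocks.
    # Note: parameter gradients inside `block` stay scaled and must be divided
    # by `scale` before the optimizer step, as with ordinary loss scaling.
    x = ScaleGrad.apply(x, 1.0 / scale)
    y = block(x)
    return ScaleGrad.apply(y, scale)
```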
Something else that we missed? Please find us on Discord.