Make new optimizer more extensible, easier to integrate downstream for FSDP #181

Open · wants to merge 19 commits into base: main
15 changes: 7 additions & 8 deletions examples/mnist_fsdp.py
@@ -147,23 +147,22 @@ def fsdp_main(rank, world_size, args):

model = Net().to(rank)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

if args.msamp:
from msamp.fsdp import FsdpReplacer
import msamp
from msamp.fsdp import FP8FullyShardedDataParallel
model = FsdpReplacer.replace(model)
from msamp.optim import FSDPAdamW
from msamp.common.dtype import Dtypes
model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)
model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
optimizer = FSDPAdamW(optimizer)
else:
model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)

if rank == 0:
print(f'FSDP model: {model}')

if args.msamp:
from msamp.optim import FSDPAdam
optimizer = FSDPAdam(model.parameters(), lr=args.lr)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
init_start_event.record()
for epoch in range(1, args.epochs + 1):
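Since the hunk above interleaves removed and added lines, the resulting setup in `fsdp_main` is easier to read in full. The following is a sketch assembled from the change; `Net`, `FSDP`, `StepLR`, `my_auto_wrap_policy`, `args`, and `rank` come from the existing example:

model = Net().to(rank)

# A plain torch optimizer is created first; msamp.initialize() remaps its param groups
# to the FP8-cast model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

if args.msamp:
    import msamp
    from msamp.fsdp import FP8FullyShardedDataParallel
    from msamp.optim import FSDPAdamW
    from msamp.common.dtype import Dtypes

    model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)
    model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
    optimizer = FSDPAdamW(optimizer)
else:
    model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)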
44 changes: 40 additions & 4 deletions msamp/__init__.py
@@ -6,15 +6,23 @@
import torch
from deepspeed.ops.adam import FusedAdam

from msamp.common.dtype import Dtypes
from msamp.nn import clip_grad_norm_
from msamp.nn import LinearReplacer
from msamp.optim import LBAdam, LBAdamW, DSAdam
from msamp.optim import LBAdam, LBAdamW, DSAdam, FSDPAdamW
from msamp.te import TeReplacer

opt_levels = ['O1', 'O2']


def initialize(model, optimizer=None, opt_level='O1', use_te=False): # noqa: C901
def initialize(
model,
optimizer=None,
opt_level='O1',
use_te=False,
weight_qtype=Dtypes.kfloat16,
use_fsdp=False,
): # noqa: C901
"""Initialize your model, optimizer according to the optimization level.

msamp.initialize() should be called after you have finished constructing your model and optimizer.
@@ -28,7 +36,8 @@ def initialize(model, optimizer=None, opt_level='O1', use_te=False): # noqa:
'O1' || fp8 || fp8 || fp16 || fp8 || fp32 + FP32
'O2' || fp8 || fp8 || fp16 || fp8 || fp8 + fp16
use_te (bool): Whether to use Transformer Engine.

weight_qtype (Dtypes): Weight quantization type.
use_fsdp (bool): Whether to prepare the model for FSDP wrapping.
Return:
return the casted model and optimizer.
"""
@@ -60,9 +69,36 @@ def initialize(model, optimizer=None, opt_level='O1', use_te=False): # noqa:
index = param_index_map[id(param)]
index_list.append(index)
if not use_te:
cast_model = LinearReplacer.replace(model)
cast_model = LinearReplacer.replace(model, weight_qtype=weight_qtype)
else:
cast_model = TeReplacer.replace(model)

if use_fsdp:
# When using FSDP, the model's named parameters change, so the optimizer's param groups must be remapped.
old_named_params = {n: p for n, p in cast_model.named_parameters()}
for _, submodule in cast_model.named_modules():
params_to_process = list(submodule.named_parameters(recurse=False))
for param_name, param in params_to_process:
if not isinstance(param, torch.Tensor):
data = param.value.view(-1)
padded = 0
if data.numel() % 4 != 0:
padded = 4 - data.numel() % 4
data = torch.nn.functional.pad(data, (0, padded))

data = data.view(dtype=torch.float32)
new_param = torch.nn.Parameter(data)
new_param._original_shape = param.shape
new_param._padded = padded
new_param._meta = param.meta
new_param._scaling_metas = param._scaling_metas

setattr(submodule, param_name, new_param)
# Map our new parameters to the optimizer param groups
new_named_params = {n: p for n, p in cast_model.named_parameters()}
mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
for param_group in optimizer.param_groups:
param_group["params"] = [mapping.get(p, p) for p in param_group["params"]]

parameters = list(cast_model.parameters())

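The pad-and-view step above is what makes the FP8 weights shardable by FSDP: the raw FP8 bytes are padded to a multiple of 4 and reinterpreted as float32, while `_original_shape`, `_padded`, and `_meta` record what is needed to reinterpret them later. A minimal round-trip sketch of that layout (illustrative only; `payload` stands in for `param.value` and is not part of this PR):

import torch

payload = torch.arange(10, dtype=torch.uint8)        # stands in for the FP8 bytes in param.value
flat = payload.view(-1)
padded = (4 - flat.numel() % 4) % 4                  # same padding rule as above
flat = torch.nn.functional.pad(flat, (0, padded))
as_fp32 = flat.view(dtype=torch.float32)             # what the new torch.nn.Parameter wraps

# The original bytes are recoverable from the float32 view:
recovered = as_fp32.view(dtype=torch.uint8)[:payload.numel()]
assert torch.equal(recovered, payload)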
37 changes: 7 additions & 30 deletions msamp/optim/adamw.py
@@ -11,6 +11,7 @@
import torch.distributed as dist

from msamp.optim import LBAdamWBase
from msamp.optim.optimizer import MSAMPOptimWrapper
from msamp.common.tensor import ScalingMeta, ScalingTensor
from msamp.common.dtype import Floating, Dtypes
import msamp_adamw
@@ -235,41 +236,16 @@ def adamw_fn( # noqa: C901
params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta))


class FSDPAdamW(LBAdamWBase):
class FSDPAdamW(MSAMPOptimWrapper):
"""Implements AdamW algorithm for FSDP."""
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
*,
maximize: bool = False,
exp_avg_dtype=torch.uint8,
exp_avg_sq_dtype=torch.float16,
tensor_scale=True,
optimizer=None,
):
"""Constructor. See LBAdamW class docstring for details."""
self.tensor_scale = tensor_scale
super().__init__(
params,
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=False,
maximize=maximize,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype
)

super().__init__(optimizer)
self.original_params = []
self.master_weights = []

for group in self.param_groups:
params = []
for param in group['params']:
@@ -290,6 +266,7 @@ def __init__(

group['params'] = params


def zero_grad(self, set_to_none=False):
"""Zero gradients."""
for param in self.original_params:
@@ -303,7 +280,7 @@ def zero_grad(self, set_to_none=False):
param.grad.requires_grad_(False)
param.grad.zero_()

def step(self):
def step(self, closure=None):
"""Performs a single optimization step."""
# Set gradient of master weight.
for i, param in enumerate(self.original_params):
@@ -314,7 +291,7 @@ def step(self):
param.grad = None

# call step() to update master weight
super().step()
super().step(closure)

# Copy master weight to weight
for i, param in enumerate(self.original_params):
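The net effect on a training loop is that `FSDPAdamW` is now a thin wrapper constructed from an existing optimizer rather than an AdamW variant with its own hyperparameters. A rough sketch, where `base_optimizer` is the optimizer returned by `msamp.initialize(..., use_fsdp=True)` and `loss` is a placeholder:

from msamp.optim import FSDPAdamW

optimizer = FSDPAdamW(base_optimizer)   # hyperparameters come from the wrapped optimizer
loss.backward()
optimizer.step()        # copy grads to master weights, step base_optimizer, write FP8 weights back
optimizer.zero_grad()   # zero grads on the original parameters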
45 changes: 45 additions & 0 deletions msamp/optim/optimizer.py
@@ -17,6 +17,51 @@
from msamp.nn import model_state, ScalingParameter


class MSAMPOptimWrapper(Optimizer):
"""
A wrapper around an optimizer for easier extensibility.
All methods are delegated to the underlying optimizer,
so that custom functionality can be added by subclassing this class.
"""
def __init__(self, optimizer):
self.optimizer = optimizer

@property
def state(self):
return self.optimizer.state

@state.setter
def state(self, state):
self.optimizer.state = state

@property
def param_groups(self):
return self.optimizer.param_groups

@property
def defaults(self):
return self.optimizer.defaults

@defaults.setter
def defaults(self, defaults):
self.optimizer.defaults = defaults

def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)

def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)

def state_dict(self):
return self.optimizer.state_dict()

def zero_grad(self, set_to_none=None):
    # Preserve the wrapped optimizer's own default when no argument is given.
    if set_to_none is None:
        self.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad(set_to_none=set_to_none)

def step(self, closure=None):
return self.optimizer.step(closure)


class LBOptimizer(Optimizer):
"""Low-bit optimizer base class.

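As an illustration of the extensibility the wrapper is meant to enable, a downstream integration could subclass it and override only `step`. `ClippedOptim` below is hypothetical and not part of this PR; it simply clips gradients before delegating to the wrapped optimizer:

import torch
from msamp.optim.optimizer import MSAMPOptimWrapper


class ClippedOptim(MSAMPOptimWrapper):
    """Hypothetical subclass: clip gradients, then delegate the step."""

    def __init__(self, optimizer, max_norm=1.0):
        super().__init__(optimizer)
        self.max_norm = max_norm

    def step(self, closure=None):
        # param_groups is forwarded to the wrapped optimizer by MSAMPOptimWrapper.
        for group in self.param_groups:
            torch.nn.utils.clip_grad_norm_(group['params'], self.max_norm)
        return super().step(closure)


# Construction mirrors FSDPAdamW: wrap an already-configured optimizer.
base = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
optimizer = ClippedOptim(base, max_norm=1.0)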