diff --git a/examples/mnist_fsdp.py b/examples/mnist_fsdp.py
index 5faed7b5..05efbc13 100644
--- a/examples/mnist_fsdp.py
+++ b/examples/mnist_fsdp.py
@@ -147,23 +147,22 @@ def fsdp_main(rank, world_size, args):

     model = Net().to(rank)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+
     if args.msamp:
-        from msamp.fsdp import FsdpReplacer
+        import msamp
         from msamp.fsdp import FP8FullyShardedDataParallel
-        model = FsdpReplacer.replace(model)
+        from msamp.optim import FSDPAdamW
+        from msamp.common.dtype import Dtypes
+        model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)
         model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
+        optimizer = FSDPAdamW(optimizer)
     else:
         model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)

     if rank == 0:
         print(f'FSDP model: {model}')

-    if args.msamp:
-        from msamp.optim import FSDPAdam
-        optimizer = FSDPAdam(model.parameters(), lr=args.lr)
-    else:
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

     init_start_event.record()
     for epoch in range(1, args.epochs + 1):
diff --git a/msamp/__init__.py b/msamp/__init__.py
index c0841e1c..87bd2a26 100644
--- a/msamp/__init__.py
+++ b/msamp/__init__.py
@@ -6,15 +6,23 @@
 import torch
 from deepspeed.ops.adam import FusedAdam

+from msamp.common.dtype import Dtypes
 from msamp.nn import clip_grad_norm_
 from msamp.nn import LinearReplacer
-from msamp.optim import LBAdam, LBAdamW, DSAdam
+from msamp.optim import LBAdam, LBAdamW, DSAdam, FSDPAdamW
 from msamp.te import TeReplacer

 opt_levels = ['O1', 'O2']


-def initialize(model, optimizer=None, opt_level='O1', use_te=False):    # noqa: C901
+def initialize(
+    model,
+    optimizer=None,
+    opt_level='O1',
+    use_te=False,
+    weight_qtype=Dtypes.kfloat16,
+    use_fsdp=False,
+):    # noqa: C901
     """Initialize your model, optimizer according to the optimization level.

     msamp.initialize() should be called after you have finished constructing your model and optimizer.
@@ -28,7 +36,8 @@ def initialize(model, optimizer=None, opt_level='O1', use_te=False):    # noqa:
             'O1'      || fp8 || fp8 || fp16 || fp8 || fp32 + FP32
             'O2'      || fp8 || fp8 || fp16 || fp8 || fp8 + fp16
         use_te (bool): Whether to use Transformer Engine.
-
+        weight_qtype (Dtypes): Weight quantization type.
+        use_fsdp (bool): Whether to prepare the model for FSDP wrapping.
     Return:
         return the casted model and optimizer.
     """
@@ -60,9 +69,36 @@ def initialize(model, optimizer=None, opt_level='O1', use_te=False):    # noqa:
             index = param_index_map[id(param)]
             index_list.append(index)

     if not use_te:
-        cast_model = LinearReplacer.replace(model)
+        cast_model = LinearReplacer.replace(model, weight_qtype=weight_qtype)
     else:
         cast_model = TeReplacer.replace(model)
+
+    if use_fsdp:
+        # When using FSDP, the named parameters of the model are different now and we need to adjust the param groups
+        old_named_params = {n: p for n, p in cast_model.named_parameters()}
+        for _, submodule in cast_model.named_modules():
+            params_to_process = list(submodule.named_parameters(recurse=False))
+            for param_name, param in params_to_process:
+                if not isinstance(param, torch.Tensor):
+                    data = param.value.view(-1)
+                    padded = 0
+                    if data.numel() % 4 != 0:
+                        padded = 4 - data.numel() % 4
+                        data = torch.nn.functional.pad(data, (0, padded))

+                    data = data.view(dtype=torch.float32)
+                    new_param = torch.nn.Parameter(data)
+                    new_param._original_shape = param.shape
+                    new_param._padded = padded
+                    new_param._meta = param.meta
+                    new_param._scaling_metas = param._scaling_metas
+
+                    setattr(submodule, param_name, new_param)
+        # Map our new parameters to the optimizer param groups
+        new_named_params = {n: p for n, p in cast_model.named_parameters()}
+        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+        for param_group in optimizer.param_groups:
+            param_group["params"] = [mapping.get(p, p) for p in param_group["params"]]

     parameters = list(cast_model.parameters())
diff --git a/msamp/optim/adamw.py b/msamp/optim/adamw.py
index 8cf68e10..9f36ea81 100644
--- a/msamp/optim/adamw.py
+++ b/msamp/optim/adamw.py
@@ -11,6 +11,7 @@
 import torch.distributed as dist

 from msamp.optim import LBAdamWBase
+from msamp.optim.optimizer import MSAMPOptimWrapper
 from msamp.common.tensor import ScalingMeta, ScalingTensor
 from msamp.common.dtype import Floating, Dtypes
 import msamp_adamw
@@ -235,41 +236,16 @@ def adamw_fn(    # noqa: C901
             params[i].copy_(param.cast(params[i].qtype, meta=params[i].meta))


-class FSDPAdamW(LBAdamWBase):
+class FSDPAdamW(MSAMPOptimWrapper):
     """Implements AdamW algorithm for FSDP."""
     def __init__(
         self,
-        params,
-        lr=1e-3,
-        bias_correction=True,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        *,
-        maximize: bool = False,
-        exp_avg_dtype=torch.uint8,
-        exp_avg_sq_dtype=torch.float16,
-        tensor_scale=True,
+        optimizer=None,
     ):
         """Constructor. See LBAdamW class docstring for details."""
-        self.tensor_scale = tensor_scale
-        super().__init__(
-            params,
-            lr=lr,
-            bias_correction=bias_correction,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsgrad=False,
-            maximize=maximize,
-            exp_avg_dtype=exp_avg_dtype,
-            exp_avg_sq_dtype=exp_avg_sq_dtype
-        )
-
+        super().__init__(optimizer)
         self.original_params = []
         self.master_weights = []
-
         for group in self.param_groups:
             params = []
             for param in group['params']:
@@ -290,6 +266,7 @@ def __init__(

             group['params'] = params

+
     def zero_grad(self, set_to_none=False):
         """Zero gradients."""
         for param in self.original_params:
@@ -303,7 +280,7 @@ def zero_grad(self, set_to_none=False):
                         param.grad.requires_grad_(False)
                     param.grad.zero_()

-    def step(self):
+    def step(self, closure=None):
         """Performs a single optimization step."""
         # Set gradient of master weight.
         for i, param in enumerate(self.original_params):
@@ -314,7 +291,7 @@ def step(self):
                 param.grad = None

         # call step() to update master weight
-        super().step()
+        super().step(closure)

         # Copy master weight to weight
         for i, param in enumerate(self.original_params):
diff --git a/msamp/optim/optimizer.py b/msamp/optim/optimizer.py
index 640461fd..984c9feb 100644
--- a/msamp/optim/optimizer.py
+++ b/msamp/optim/optimizer.py
@@ -17,6 +17,51 @@
 from msamp.nn import model_state, ScalingParameter


+class MSAMPOptimWrapper(Optimizer):
+    """A wrapper around an optimizer for easier extensibility.
+
+    All methods are delegated to the underlying optimizer, so that custom
+    functionality can be added by subclassing this class.
+    """
+    def __init__(self, optimizer):
+        self.optimizer = optimizer
+
+    @property
+    def state(self):
+        return self.optimizer.state
+
+    @state.setter
+    def state(self, state):
+        self.optimizer.state = state
+
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
+    @property
+    def defaults(self):
+        return self.optimizer.defaults
+
+    @defaults.setter
+    def defaults(self, defaults):
+        self.optimizer.defaults = defaults
+
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+
+    def load_state_dict(self, state_dict):
+        self.optimizer.load_state_dict(state_dict)
+
+    def state_dict(self):
+        return self.optimizer.state_dict()
+
+    def zero_grad(self, set_to_none=None):
+        self.optimizer.zero_grad(set_to_none)
+
+    def step(self, closure=None):
+        return self.optimizer.step(closure)
+
+
 class LBOptimizer(Optimizer):
     """Low-bit optimizer base class.
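
For reference, a condensed sketch of the FSDP flow this patch enables, pieced together from the examples/mnist_fsdp.py hunk above; Net, my_auto_wrap_policy, args and rank are definitions taken from that example, not new API.

import torch
import msamp
from msamp.common.dtype import Dtypes
from msamp.fsdp import FP8FullyShardedDataParallel
from msamp.optim import FSDPAdamW

model = Net().to(rank)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

# msamp.initialize(use_fsdp=True) casts linear layers to FP8, flattens each scaling
# parameter into a padded float32 view, and remaps the optimizer's param groups.
model, optimizer = msamp.initialize(model, optimizer, use_fsdp=True, weight_qtype=Dtypes.kfloat8_e4m3)

# Wrap the casted model with the FP8-aware FSDP class, then delegate optimizer calls
# through FSDPAdamW, which sets master-weight gradients and copies master weights back in step().
model = FP8FullyShardedDataParallel(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
optimizer = FSDPAdamW(optimizer)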