Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Sep 1, 2023
1 parent 7f07eaf commit 7675d05
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions pipegoose/nn/tensor_parallel/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def deparallelize(self):

def _parallelize_column_linear(self, module: nn.Module) -> nn.Module:
module.__class__ = ColumnParallelLinear
module = self._assign_partition(module, dim=0)
module = self._slice_weight_and_bias(module, slice_bias=True, dim=0)

_update_model_arguments(
module=module,
Expand All @@ -59,7 +59,9 @@ def _parallelize_column_linear(self, module: nn.Module) -> nn.Module:

def _parallelize_row_linear(self, module: nn.Module) -> nn.Module:
module.__class__ = RowParallelLinear
module = self._assign_partition(module, dim=1)
# NOTE: It appears that row column doesn't not require splitting the bias,
# as the final output without splitting is correct
module = self._slice_weight_and_bias(module, slice_bias=False, dim=1)

_update_model_arguments(
module=module,
Expand All @@ -68,13 +70,10 @@ def _parallelize_row_linear(self, module: nn.Module) -> nn.Module:

return module

def _assign_partition(self, module: nn.Module, dim: int) -> nn.Module:
def _slice_weight_and_bias(self, module: nn.Module, slice_bias: bool, dim: int) -> nn.Module:
module.weight.data = get_partition(module.weight, self.parallel_context, dim=dim)

# NOTE: A linear column (dim=0) requires splitting the bias.
# It appears that row-columns do not require splitting the bias,
# as the final output without splitting is correct
if dim == 0:
if slice_bias is True:
if module.bias is not None:
module.bias.data = get_partition(module.bias, self.parallel_context, dim=0)

Expand Down

0 comments on commit 7675d05

Please sign in to comment.