Support deepspeed sequence parallel #31525

Open · wants to merge 15 commits into base: main
17 changes: 16 additions & 1 deletion src/transformers/integrations/deepspeed.py
@@ -137,9 +137,24 @@ def trainer_config_process(self, args, auto_find_batch_size=False):
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
         """
+
+        if getattr(args, "sequence_parallel", 1) > 1:
+            assert (
+                is_accelerate_available()
+            ), "DeepSpeed sequence parallelism requires Accelerate, install it with 'pip install accelerate'"
+
+            from accelerate.utils import parallel_state as mpu
+
+            mpu.initialize_model_parallel(
+                sequence_parallel_size=args.sequence_parallel,
+            )
+            world_size = mpu.get_data_parallel_world_size()
+        else:
+            world_size = args.world_size
+
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        train_batch_size = world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
             "train_micro_batch_size_per_gpu",
             args.per_device_train_batch_size,
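To make the effect of swapping `args.world_size` for the data-parallel world size concrete, the batch-size arithmetic under sequence parallelism works out as in the sketch below (illustrative numbers; the only assumption is that `mpu.get_data_parallel_world_size()` returns the launch world size divided by the sequence-parallel degree, as set up by `initialize_model_parallel`):

```python
# Hypothetical launch: 8 GPUs with sequence_parallel=4.
# Each sample's sequence is sharded across 4 ranks, so only 8 / 4 = 2 ranks
# hold distinct micro-batches -- that is the data-parallel world size.
world_size = 8
sequence_parallel = 4
per_device_train_batch_size = 2
gradient_accumulation_steps = 4

data_parallel_world_size = world_size // sequence_parallel  # 2

# What the patched trainer_config_process hands to DeepSpeed:
train_batch_size = data_parallel_world_size * per_device_train_batch_size * gradient_accumulation_steps
print(train_batch_size)  # 16 -- using args.world_size here would wrongly report 64
```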
91 changes: 91 additions & 0 deletions src/transformers/layer.py
@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Tuple

import torch
from torch import Tensor
from torch import distributed as dist
from torch.nn import Module


def single_all_to_all(input, scatter_idx, gather_idx, group):
    if dist.get_world_size(group) <= 1:
        return input
    seq_world_size = dist.get_world_size(group)

    input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    # TODO: use all_to_all_single instead
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_idx).contiguous()


class _SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return single_all_to_all(input, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)


class DistributedAttention(torch.nn.Module):
    """Attention layer that redistributes q, k and v with all-to-all so each rank attends over the full sequence.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm

    Only the layout [b, s, h, d] is supported.
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,  # scatter dim, commonly the head dim
        gather_idx: int = 1,  # gather dim, commonly the sequence dim
    ) -> None:
        super(DistributedAttention, self).__init__()

        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # TODO: merge the three all-to-all calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k and v) together!
        # in shape: e.g. [b, s/p, h, d]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # out shape: e.g. [b, s, h/p, d]
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out shape: e.g. [b, s/p, h, d]
        return output
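Below is a minimal single-process sketch of how `DistributedAttention` is driven. With a world size of 1 the all-to-all degenerates to a no-op, so this only exercises the plumbing and the expected `[b, s, h, d]` layout, not the actual sequence sharding; the import path `transformers.layer` is the one this PR introduces, and `naive_attention` is a stand-in for `_flash_attention_forward`:

```python
import os

import torch
import torch.distributed as dist

from transformers.layer import DistributedAttention  # module added by this PR

# Single-process "cluster" so that dist.get_world_size(group) works.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


def naive_attention(q, k, v, *args, **kwargs):
    # Stand-in local attention over the [b, s, h, d] layout.
    scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)


# scatter_idx=2 / gather_idx=1: split heads across ranks, gather the full sequence.
attn = DistributedAttention(naive_attention, dist.group.WORLD)

b, s_local, h, d = 2, 16, 8, 32  # s_local would be full_seq_len / sequence_parallel_size
q = k = v = torch.randn(b, s_local, h, d)
out = attn(q, k, v)
print(out.shape)  # torch.Size([2, 16, 8, 32]) -- the caller gets its local sequence chunk back
```

In a real run each rank would hold a different sequence chunk, the first all-to-all would trade head slices for sequence slices, and the reverse all-to-all would hand the local chunk of the context back to the caller.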
19 changes: 18 additions & 1 deletion src/transformers/modeling_utils.py
@@ -121,6 +121,7 @@
         save_offload_index,
         set_module_tensor_to_device,
     )
+    from accelerate.utils import parallel_state as mpu
 
     accelerate_version = version.parse(importlib.metadata.version("accelerate"))
     if accelerate_version >= version.parse("0.31"):
@@ -3705,7 +3706,19 @@ def from_pretrained(
             import deepspeed
 
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
+            if is_accelerate_available() and mpu.model_parallel_is_initialized():
+                mpu_module = mpu
+                sequence_data_parallel_group = mpu.get_sequence_data_parallel_group()
+            else:
+                mpu_module = None
+                sequence_data_parallel_group = None
+            init_contexts = [
+                deepspeed.zero.Init(
+                    sequence_data_parallel_group=sequence_data_parallel_group,
+                    config_dict_or_path=deepspeed_config(),
+                    mpu=mpu_module,
+                )
+            ] + init_contexts
         elif low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
 
@@ -3892,6 +3905,10 @@ def from_pretrained(
                 )
                 pass
 
+        if is_accelerate_available() and mpu.get_sequence_parallel_world_size_or_one() > 1:
+            if not getattr(model, "supports_sequence_parallel", False):
+                raise ValueError("The model does not support sequence parallelism.")
+
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
             device_map_kwargs = {
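Putting the pieces together, enabling the feature from user code would look roughly like the sketch below. It assumes the PR also adds a `sequence_parallel` field to `TrainingArguments` (that part of the diff is not shown here), so the exact argument name is provisional; `supports_sequence_parallel` is the opt-in flag checked above and set on `LlamaPreTrainedModel` further down.

```python
from transformers import AutoModelForCausalLM, TrainingArguments

# Provisional sketch: `sequence_parallel` is the value read via
# getattr(args, "sequence_parallel", 1) in trainer_config_process; the field itself
# is assumed to be added to TrainingArguments elsewhere in this PR.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    deepspeed="ds_zero3_config.json",  # ZeRO-3 config; zero.Init then uses the sequence-data-parallel group
    sequence_parallel=4,               # shard every sequence across 4 ranks
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="flash_attention_2",
)
# from_pretrained raises a ValueError if the architecture does not opt in.
assert getattr(model, "supports_sequence_parallel", False)

# From here the Trainer is used as usual; the DeepSpeed config picks up the adjusted
# train_batch_size computed from the data-parallel (not global) world size.
```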
48 changes: 42 additions & 6 deletions src/transformers/models/llama/modeling_llama.py
@@ -42,13 +42,19 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
 from .configuration_llama import LlamaConfig
 
 
+if is_accelerate_available():
+    from accelerate.utils import parallel_state as mpu
+
+    from ...layer import DistributedAttention
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -374,6 +380,13 @@ def __init__(self, *args, **kwargs):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(_flash_attention_forward, mpu.get_sequence_parallel_group())
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = _flash_attention_forward
+            self.q_len_multiplier = 1
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -447,12 +460,12 @@ def forward(
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        attn_output = _flash_attention_forward(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
+            q_len * self.q_len_multiplier,
             dropout=dropout_rate,
             sliding_window=getattr(self, "sliding_window", None),
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
@@ -475,6 +488,21 @@ class LlamaSdpaAttention(LlamaAttention):
     SDPA API.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.attn_func = DistributedAttention(
+                torch.nn.functional.scaled_dot_product_attention,
+                mpu.get_sequence_parallel_group(),
+                scatter_idx=1,
+                gather_idx=2,
+            )
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.attn_func = torch.nn.functional.scaled_dot_product_attention
+            self.q_len_multiplier = 1
+
     # Adapted from LlamaAttention.forward
     def forward(
         self,
@@ -526,7 +554,7 @@ def forward(

         causal_mask = attention_mask
         if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2] * self.q_len_multiplier]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -539,7 +567,7 @@ def forward(
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
+        attn_output = self.attn_func(
             query_states,
             key_states,
             value_states,
@@ -663,6 +691,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    supports_sequence_parallel = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
@@ -781,6 +810,11 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
+        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
+            self.q_len_multiplier = mpu.get_sequence_parallel_world_size()
+        else:
+            self.q_len_multiplier = 1
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -837,7 +871,9 @@ def forward(
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1] * self.q_len_multiplier,
+                device=inputs_embeds.device,
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -943,7 +979,7 @@ def _update_causal_mask(

         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
+        sequence_length = input_tensor.shape[1] * self.q_len_multiplier
         if using_static_cache:
             target_length = past_key_values.get_max_length()
         else:
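One detail that explains the swapped indices above: `_flash_attention_forward` consumes tensors laid out as `[batch, seq, heads, head_dim]`, so the flash path keeps the defaults `scatter_idx=2, gather_idx=1` (scatter heads, gather sequence), while `scaled_dot_product_attention` consumes `[batch, heads, seq, head_dim]`, hence `scatter_idx=1, gather_idx=2`. The same reasoning motivates `q_len_multiplier`: the locally visible sequence length is the full length divided by the sequence-parallel degree, so lengths and mask slices are scaled back up. A shape-only sketch with made-up sizes (no distributed calls):

```python
# Sequence-parallel degree p=4; each rank sees a 128-token chunk of a 512-token sequence.
p, b, s_local, h, d = 4, 2, 128, 16, 64


def all_to_all_shape(shape, scatter_idx, gather_idx, p):
    # What _SeqAllToAll does to a shape: the scatter dim is split p ways
    # (one slice per rank) and the gather dim grows p-fold (slices from all ranks).
    out = list(shape)
    out[scatter_idx] //= p
    out[gather_idx] *= p
    return tuple(out)


# Flash-attention layout [b, s, h, d]: gather the full sequence, scatter the heads.
print(all_to_all_shape((b, s_local, h, d), scatter_idx=2, gather_idx=1, p=p))  # (2, 512, 4, 64)

# SDPA layout [b, h, s, d]: heads sit on dim 1 and sequence on dim 2, so the indices swap.
print(all_to_all_shape((b, h, s_local, d), scatter_idx=1, gather_idx=2, p=p))  # (2, 4, 512, 64)
```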