Commit ed2d5e4
add env to support different mm layout.
Reinerzhou committed Nov 29, 2024
1 parent 3913ead commit ed2d5e4
Showing 2 changed files with 77 additions and 0 deletions.
41 changes: 41 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/maca/linear.py
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch

from lmdeploy.pytorch.kernels.dlinfer import linear

from ...linear import LinearBuilder, LinearImpl


class DlinferMacaLinearImpl(LinearImpl):
    """Dlinfer MACA linear implementation."""

    def update_weights(self,
                       weight: torch.Tensor,
                       bias: Optional[torch.Tensor] = None):
        """update weights."""
        if os.getenv('MACA_USE_NN_LAYOUT', 'True').lower() == 'true':
            # (out_features, in_features) -> (in_features, out_features)
            weight = weight.data.t().contiguous()
        return weight, bias

    def forward(self,
                x,
                weight: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        """forward."""
        return linear(x, weight, bias, all_reduce)


class DlinferMacaLinearBuilder(LinearBuilder):
    """Dlinfer MACA linear builder."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: Optional[torch.dtype] = None):
        """build."""
        return DlinferMacaLinearImpl()
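
To make the layout switch concrete, here is a minimal, self-contained sketch; the env-var check and the transpose mirror update_weights above, while maybe_transpose and the example shapes are illustrative assumptions, not lmdeploy code:

import os

import torch


def maybe_transpose(weight: torch.Tensor) -> torch.Tensor:
    # Mirrors DlinferMacaLinearImpl.update_weights: with MACA_USE_NN_LAYOUT
    # enabled (the default), a weight stored in the nn.Linear convention
    # (out_features, in_features) is transposed to (in_features, out_features).
    if os.getenv('MACA_USE_NN_LAYOUT', 'True').lower() == 'true':
        weight = weight.data.t().contiguous()
    return weight


w = torch.randn(4096, 1024)        # nn.Linear convention: (out, in)
print(maybe_transpose(w).shape)    # torch.Size([1024, 4096]) by default
os.environ['MACA_USE_NN_LAYOUT'] = 'false'
print(maybe_transpose(w).shape)    # torch.Size([4096, 1024]) when disabled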
36 changes: 36 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -5,6 +5,7 @@

from lmdeploy.utils import get_logger

from ...base import OpType
from ..op_backend import DlinferOpsBackend

logger = get_logger('lmdeploy')
@@ -18,6 +19,41 @@ def get_name() -> str:
        """backend name."""
        return 'maca'

    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        """get dlinfer layer builder."""
        if layer_type == OpType.Attention:
            from ..attention import DlinferAttentionBuilder
            return DlinferAttentionBuilder
        elif layer_type == OpType.ApplyRotaryEmb:
            from ..apply_rotary_emb import DlinferApplyRotaryEmbBuilder
            return DlinferApplyRotaryEmbBuilder
        elif layer_type == OpType.SiluAndMul:
            from ..activation import DlinferSiluAndMulBuilder
            return DlinferSiluAndMulBuilder
        elif layer_type == OpType.RMSNorm:
            from ..norm import DlinferRMSNormBuilder
            return DlinferRMSNormBuilder
        elif layer_type == OpType.SoftmaxTopK:
            from ..moe import DlinferSoftmaxTopKBuilder
            return DlinferSoftmaxTopKBuilder
        elif layer_type == OpType.FusedMoE:
            from ..moe import DlinferFusedMoEBuilder
            return DlinferFusedMoEBuilder
        elif layer_type == OpType.Linear:
            from .linear import DlinferMacaLinearBuilder
            return DlinferMacaLinearBuilder
        elif layer_type == OpType.LinearW4A16:
            from ..awq_modules import AwqLinearW4A16Builder
            return AwqLinearW4A16Builder
        elif layer_type == OpType.RotaryEmbedding:
            from ..rotary_embedding import DlinferRotaryEmbeddingBuilder
            return DlinferRotaryEmbeddingBuilder
        else:
            logger.debug(
                f'Op {layer_type} fallback to default implementation.')
            return super().get_layer_impl_builder(layer_type)

    @staticmethod
    def get_k_block_shape(
        block_size: int,
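For context, a runnable sketch of the dispatch pattern get_layer_impl_builder follows: the MACA backend overrides the builder lookup only for the op types it specializes (here just Linear, as in this commit) and defers everything else to its parent via super(). The OpType members and class bodies below are simplified stand-ins, not lmdeploy's definitions:

from enum import Enum, auto


class OpType(Enum):  # stand-in for lmdeploy's OpType
    Attention = auto()
    Linear = auto()
    RMSNorm = auto()


class DlinferOpsBackend:  # stand-in parent with a default lookup
    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        return f'Dlinfer{layer_type.name}Builder'


class MacaOpsBackend(DlinferOpsBackend):
    @classmethod
    def get_layer_impl_builder(cls, layer_type: OpType):
        if layer_type == OpType.Linear:
            # MACA needs its own Linear to apply the weight layout above.
            return 'DlinferMacaLinearBuilder'
        return super().get_layer_impl_builder(layer_type)


print(MacaOpsBackend.get_layer_impl_builder(OpType.Linear))   # MACA override
print(MacaOpsBackend.get_layer_impl_builder(OpType.RMSNorm))  # parent default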
