diff --git a/lmdeploy/pytorch/backends/dlinfer/maca/linear.py b/lmdeploy/pytorch/backends/dlinfer/maca/linear.py
new file mode 100644
index 000000000..28102e9ad
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/maca/linear.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Optional
+
+import torch
+
+from lmdeploy.pytorch.kernels.dlinfer import linear
+
+from ...linear import LinearBuilder, LinearImpl
+
+
+class DlinferMacaLinearImpl(LinearImpl):
+    """Dlinfer linear implementation api."""
+
+    def update_weights(self,
+                       weight: torch.Tensor,
+                       bias: Optional[torch.Tensor] = None):
+        """update weights."""
+        if os.getenv('MACA_USE_NN_LAYOUT', 'True').lower() == 'true':
+            weight = weight.data.t().contiguous()
+        return weight, bias
+
+    def forward(self,
+                x,
+                weight: torch.Tensor,
+                bias: Optional[torch.Tensor] = None,
+                all_reduce: bool = False):
+        """forward."""
+        return linear(x, weight, bias, all_reduce)
+
+
+class DlinferMacaLinearBuilder(LinearBuilder):
+    """Dlinfer linear implementation builder."""
+
+    @staticmethod
+    def build(in_features: int,
+              out_features: int,
+              bias: bool = True,
+              dtype: torch.dtype = None):
+        """build."""
+        return DlinferMacaLinearImpl()
diff --git a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
index 084cae1bf..d5b0f6cd1 100644
--- a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -5,6 +5,7 @@
 
 from lmdeploy.utils import get_logger
 
+from ...base import OpType
 from ..op_backend import DlinferOpsBackend
 
 logger = get_logger('lmdeploy')
@@ -18,6 +19,41 @@ def get_name() -> str:
         """backend name."""
         return 'maca'
 
+    @classmethod
+    def get_layer_impl_builder(cls, layer_type: OpType):
+        """get dlinfer layer builder."""
+        if layer_type == OpType.Attention:
+            from ..attention import DlinferAttentionBuilder
+            return DlinferAttentionBuilder
+        elif layer_type == OpType.ApplyRotaryEmb:
+            from ..apply_rotary_emb import DlinferApplyRotaryEmbBuilder
+            return DlinferApplyRotaryEmbBuilder
+        elif layer_type == OpType.SiluAndMul:
+            from ..activation import DlinferSiluAndMulBuilder
+            return DlinferSiluAndMulBuilder
+        elif layer_type == OpType.RMSNorm:
+            from ..norm import DlinferRMSNormBuilder
+            return DlinferRMSNormBuilder
+        elif layer_type == OpType.SoftmaxTopK:
+            from ..moe import DlinferSoftmaxTopKBuilder
+            return DlinferSoftmaxTopKBuilder
+        elif layer_type == OpType.FusedMoE:
+            from ..moe import DlinferFusedMoEBuilder
+            return DlinferFusedMoEBuilder
+        elif layer_type == OpType.Linear:
+            from .linear import DlinferMacaLinearBuilder
+            return DlinferMacaLinearBuilder
+        elif layer_type == OpType.LinearW4A16:
+            from ..awq_modules import AwqLinearW4A16Builder
+            return AwqLinearW4A16Builder
+        elif layer_type == OpType.RotaryEmbedding:
+            from ..rotary_embedding import DlinferRotaryEmbeddingBuilder
+            return DlinferRotaryEmbeddingBuilder
+        else:
+            logger.debug(
+                f'Op {layer_type} fallback to default implementation.')
+            return super().get_layer_impl_builder(layer_type)
+
     @staticmethod
     def get_k_block_shape(
         block_size: int,