diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py index 567a01ddd..f2450a974 100644 --- a/lmdeploy/pytorch/backends/dlinfer/linear.py +++ b/lmdeploy/pytorch/backends/dlinfer/linear.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Optional import torch @@ -11,6 +12,14 @@ class DlinferLinearImpl(LinearImpl): """Dlinfer linear implementation api.""" + def update_weights(self, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + if os.getenv('TORCH_MACA_NN_LAYOUT', 'False').lower() == 'true': + weight = weight.data.t().contiguous() + return weight, bias + def forward(self, x, weight: torch.Tensor,