From 99b3d265555cd5df661894b64ebee50cdd85967c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 22 Nov 2024 14:47:45 -0800 Subject: [PATCH] [torchlib] Fix aten_mean_dim (#1962) Fix when the rank of `dim` is not known, the conditional will pick the wrong branch because the truthiness of a value is always True. --- onnxscript/function_libs/torch_lib/ops/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a955583e9..4c22df181 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5225,9 +5225,8 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: if IsScalar(self): result = self else: - if IsScalar(dim): - dim = op.Unsqueeze(dim, axes=0) - result = op.ReduceMean(self, dim, keepdims=keepdim) + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + result = op.ReduceMean(self, dims, keepdims=keepdim) return result