Hi, I compared the results of the moe_permute/moe_unpermute operators against PyTorch permute/unpermute at different data precisions, and found that they are not consistent: the maximum gradient error can be on the order of 1e+00. This might impact model performance, and I wonder if you plan to improve it. Below is a snippet of test code and the detailed output.
More specifically, the max error of the unpermute operator is more severe than that of permute, and the error increases as topK grows.
For instance, at bf16 precision the unpermute backward error can reach 8.9e-01, while its forward error is 7.8e-03. Meanwhile, the permute backward error can reach 3.7e+00 even though its forward matches exactly. This has had some impact on training loss.
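Unit test code (the original snippet is not preserved in this thread; what follows is a minimal reconstruction of a Megatron-style PyTorch reference, with the function names and the final comparison step being my assumptions):

```python
import torch

def pytorch_permute(tokens, indices):
    # Flatten the [tokens, topK] expert assignments and sort them so that
    # rows routed to the same expert become contiguous.
    topk = indices.size(1)
    sorted_indices = torch.argsort(indices.view(-1), stable=True)
    permuted = tokens.index_select(0, sorted_indices // topk)
    return permuted, sorted_indices

def pytorch_unpermute(permuted, sorted_indices, probs):
    # Scatter rows back to their original [token, topK] slots, then do the
    # weighted reduction over topK -- the reduction whose rounding can
    # differ between implementations.
    topk, hidden = probs.size(1), permuted.size(1)
    out = torch.zeros(probs.numel(), hidden,
                      dtype=permuted.dtype, device=permuted.device)
    out.index_copy_(0, sorted_indices, permuted)
    return (out.view(-1, topk, hidden) * probs.unsqueeze(-1)).sum(dim=1)

# Quick sanity check of the reference pair (roundtrip with uniform probs):
t = torch.randn(8, 16)
idx = torch.randint(4, (8, 2), dtype=torch.int64)
p, s = pytorch_permute(t, idx)
restored = pytorch_unpermute(p, s, torch.full((8, 2), 0.5))
assert torch.allclose(restored, t)  # two 0.5-weighted copies sum back to the original

# The test then runs this path and the TE path (moe_permute/moe_unpermute
# from transformer_engine; exact signatures vary by version) on identical
# requires_grad inputs, calls backward() with the same grad_output, and
# prints the max abs difference of the resulting input gradients per dtype.
```

And here is the output of max error: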
```
torch.float32 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch): 3.611e+00 (torch.float32)
====================== Unpermute ======================
[unpermute fwd] (te vs pytorch) equal
[unpermute bwd] max error (te vs pytorch): 8.805e-01 (torch.float32)
torch.float16 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch): 3.593e+00 (torch.float16)
====================== Unpermute ======================
[unpermute fwd] max error (te vs pytorch): 9.767e-04 (torch.float16)
[unpermute bwd] max error (te vs pytorch): 9.252e-01 (torch.float16)
torch.bfloat16 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch): 3.732e+00 (torch.bfloat16)
====================== Unpermute ======================
[unpermute fwd] max error (te vs pytorch): 7.813e-03 (torch.bfloat16)
[unpermute bwd] max error (te vs pytorch): 8.945e-01 (torch.bfloat16)
```
The behavior of these operators varies across dtypes and has affected the training of our model. I tried combining moe_permute with pyTorch_unpermute, but the max gradient error persists. Could you please help me? Any suggestion would be appreciated.
Hi @NiuMa-1234, the difference observed between the unpermute kernels of PyTorch and TE can be attributed to rounding errors.
The unpermute process involves reduction calculations, and different implementations will naturally yield slightly different results. The error also grows with topK, since the reduction runs along the topK dimension; the sketch below illustrates this.
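As a small standalone illustration (plain PyTorch, not TE code): summing the same topK contributions under two different associations already yields bf16 results that disagree.

```python
import torch

torch.manual_seed(0)
# 8192 tokens, topK=4 contributions per token, all stored in bf16.
w = torch.randn(8192, 4).to(torch.bfloat16)

left  = ((w[:, 0] + w[:, 1]) + w[:, 2]) + w[:, 3]   # one summation order
right = ((w[:, 3] + w[:, 2]) + w[:, 1]) + w[:, 0]   # the opposite order

# Same mathematical sum, but each partial result rounds differently.
print((left - right).abs().max().item())  # typically nonzero in bf16
```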
If you want a "higher precision", you can try changing TCompute in this file to float32. While I dont think this will reduce the gap between the results from pytorch and TE, but it will certainly increase the precision of the operators from TE.