the max error of moe_permute/unpermute.grad could reach 3.6e+00 #1336

Open
NiuMa-1234 opened this issue Nov 15, 2024 · 1 comment

NiuMa-1234 commented Nov 15, 2024

Hi, I compared the results of the moe_permute/moe_unpermute operators against a PyTorch permute/unpermute implementation at different data precisions, and found that they are not consistent: the maximum gradient error can reach the order of e+00. This might affect model quality, and I wonder if you plan to improve it. Below are a test-code snippet and the detailed output.
More specifically, the unpermute operator disagrees in both the forward and backward passes, while permute disagrees only in backward, and the error increases as topK grows.
For instance, with bf16, the unpermute backward error can reach 8.9e-01 while its forward error is only 7.8e-03. Meanwhile, the permute backward error can reach 3.7e+00 even though its forward results are identical. This has had some impact on training loss.

Unit test Code:


import torch
from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute
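# Reference permute in pure PyTorch: gather topK copies of each token, grouped by expert index.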
def permute(tokens, indices, num_out_tokens: int = None, padded_mode: bool = False):
    if indices.dim() == 1:
        topk = 1
    else:
        topk = indices.size(1)
    flatten_indices = indices.view(-1)
    sorted_indices = torch.argsort(flatten_indices, stable=True)
    if num_out_tokens is not None:
        sorted_indices = sorted_indices[:num_out_tokens]
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    return permuted_tokens, sorted_indices
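# Reference unpermute in pure PyTorch: scatter permuted rows back, weight by probs, and sum over topK.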
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: torch.Tensor = None,
    padded_mode: bool = False,
    restore_shape: torch.Size = None,
):

    assert sorted_indices.numel() == permuted_tokens.size(0)
    if probs is not None:
        num_unpermuted_tokens = probs.numel()
        topk = probs.size(1)
    else:
        num_unpermuted_tokens = permuted_tokens.size(0)
        topk = 1

    unpermuted_tokens = torch.zeros(
        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    sorted_indices = sorted_indices.long()
    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
    unpermuted_tokens = unpermuted_tokens.sum(dim=1)
    return unpermuted_tokens
def permute_topK_test(
    dtype,
    num_token,
    num_expert,
    hidden_size,
    num_topK,
    PRINT,
    BENCHMARK):

    num_out_tokens = num_token * num_topK
    print(f"\n{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}")

    is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn]

    permute_input = torch.rand((num_token, hidden_size), dtype=torch.float32).cuda()

    permute_input = permute_input.to(dtype)
    if is_fp8:
        permute_input = permute_input.half()

    permute_input.requires_grad_(True)
    te_permute_input = permute_input.detach().to(dtype)
    te_permute_input.requires_grad_(True)

    if num_token > 0:
        indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)])
    else:
        indices = torch.empty((num_token, num_topK))
    indices = indices.to(torch.int32).cuda()
    probs = torch.rand(num_token, num_topK).cuda()
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs.requires_grad_(True)

    print("====================== Permute ======================")

    permute_output, sorted_indices = permute(permute_input, indices, num_out_tokens=num_out_tokens)
    permute_bwd_input = torch.rand_like(permute_output)
    permute_output.backward(permute_bwd_input, retain_graph=True)

    
    te_permute_output, te_row_id_map = te_permute(te_permute_input, indices, num_out_tokens=num_out_tokens)
    te_permute_bwd_input = torch.rand_like(te_permute_output)
    te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

    if not torch.allclose(permute_output.float(), te_permute_output.float()):
        original_inputs = te_permute_output.float().cpu().numpy().flatten()
        original_output = permute_output.float().cpu().numpy().flatten()
        max_abs_error = abs(original_inputs - original_output).max()
        print(f"[permute fwd] (te vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})")
    else:
        print(f"[permute fwd] (te vs pytorch) equal")

    if not torch.allclose(permute_input.grad.float(), te_permute_input.grad.float()):
        original_inputs = te_permute_input.grad.float().cpu().numpy().flatten()
        original_output = permute_input.grad.float().cpu().numpy().flatten()
        max_abs_error = abs(original_inputs - original_output).max()
        print(f"[permute bwd] (te vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})")
    else:
        print(f"[permute bwd] (te vs pytorch) equal")
    
     
    print("====================== Unpermute ======================")
    unpermute_input = permute_output.detach()
    unpermute_input.requires_grad_(True)

    
    te_unpermute_input = te_permute_output.detach()
    te_unpermute_input.requires_grad_(True)
    te_probs = probs.detach()
    te_probs.requires_grad_(True)
    
    

    unpermute_output = unpermute(unpermute_input, sorted_indices, probs=probs)
    unpermute_bwd_input = torch.rand_like(unpermute_output)
    unpermute_output.backward(unpermute_bwd_input, retain_graph=True)

    
    te_unpermute_output = te_unpermute(te_unpermute_input, te_row_id_map, te_probs)
    te_unpermute_bwd_input = torch.rand_like(te_unpermute_output)
    te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

    if not torch.allclose(unpermute_output.float(), te_unpermute_output.float()):
        original_inputs = unpermute_output.float().cpu().detach().numpy().flatten()
        original_output = te_unpermute_output.float().cpu().detach().numpy().flatten()
        max_abs_error = abs(original_inputs - original_output).max()
        print(f"[unpermute fwd] max error (te vs pytorch): \t{max_abs_error:.3e} ({dtype})")
    else:
        print(f"[unpermute fwd] (te vs pytorch) equal")

    if not torch.allclose(unpermute_input.grad.float(), te_unpermute_input.grad.float()):
        original_inputs = te_unpermute_input.grad.float().cpu().detach().numpy().flatten()
        original_output = unpermute_input.grad.float().cpu().detach().numpy().flatten()
        max_abs_error = abs(original_inputs - original_output).max()
        print(f"[unpermute bwd] max error (te vs pytorch): \t{max_abs_error:.3e} ({dtype})")
    else:
        print(f"[unpermute bwd] (te vs pytorch) equal")

def test_permute_topK():

    torch.manual_seed(1)

    num_token = 4096 * 2
    num_expert = 16
    hidden_size = 2048
    num_topK = 4

    PRINT=False
    Benchmark = False
    print("GPU:", torch.cuda.get_device_name(0))

    dtype = torch.float32
    permute_topK_test(dtype, num_token, num_expert,
                      hidden_size, num_topK, PRINT, Benchmark)
    dtype = torch.float16
    permute_topK_test(dtype, num_token, num_expert,
                      hidden_size, num_topK, False, Benchmark)
    dtype = torch.bfloat16
    permute_topK_test(dtype, num_token, num_expert,
                      hidden_size, num_topK, False, Benchmark)
    dtype = torch.float8_e5m2
    permute_topK_test(dtype, num_token, num_expert,
                      hidden_size, num_topK, False, Benchmark)
    dtype = torch.float8_e4m3fn
    permute_topK_test(dtype, num_token, num_expert,
                      hidden_size, num_topK, False, Benchmark)
    dtype = torch.bfloat16
    permute_topK_test(dtype, num_token, 4, hidden_size, 1, False, Benchmark)
    permute_topK_test(dtype, num_token, 5, hidden_size, 2, False, Benchmark)
    permute_topK_test(dtype, num_token, 6, hidden_size, 3, False, Benchmark)
    permute_topK_test(dtype, num_token, 7, hidden_size, 4, False, Benchmark)
    permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark)
    num_token = 0
    permute_topK_test(dtype, num_token, 8, hidden_size, 5, False, Benchmark)

if __name__ == "__main__":
    test_permute_topK()


And here is the output showing the max errors:

torch.float32 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch):                  3.611e+00 (torch.float32)
====================== Unpermute ======================
[unpermute fwd] (te vs pytorch) equal
[unpermute bwd] max error (te vs pytorch):      8.805e-01 (torch.float32)

torch.float16 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch):                  3.593e+00 (torch.float16)
====================== Unpermute ======================
[unpermute fwd] max error (te vs pytorch):      9.767e-04 (torch.float16)
[unpermute bwd] max error (te vs pytorch):      9.252e-01 (torch.float16)

torch.bfloat16 token:8192 hidden_size:2048 expert:16 topK:4
====================== Permute ======================
[permute fwd] (te vs pytorch) equal
[permute bwd] (te vs pytorch):                  3.732e+00 (torch.bfloat16)
====================== Unpermute ======================
[unpermute fwd] max error (te vs pytorch):      7.813e-03 (torch.bfloat16)
[unpermute bwd] max error (te vs pytorch):      8.945e-01 (torch.bfloat16)

The accuracy of these operators varies across dtypes, and this has affected our model training. I also tried combining moe_permute with the PyTorch unpermute, but the max gradient error remains. Could you please help? Any suggestion would be appreciated.

@StudyingShao
Contributor

Hi @NiuMa-1234, the diff observed between the unpermute kernels of PyTorch and TE can be attributed to rounding errors.

The unpermute process involves reductions, and different implementations will naturally yield slightly different results. The diff also grows with topK, since the reduction runs along the topK dimension.
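For illustration only, here is a minimal standalone sketch (not TE's kernel; the reduction length is exaggerated to make the effect visible) showing that the same bf16 sum computed with a different accumulation order/precision already disagrees:

import torch

torch.manual_seed(0)
x = torch.rand(1024).bfloat16()  # stand-in for the values one output row reduces over

# Strictly sequential accumulation in bf16.
acc = torch.tensor(0.0, dtype=torch.bfloat16)
for v in x:
    acc = acc + v

# PyTorch's built-in reduction, which uses a different order/accumulator.
ref = x.sum()

print((acc.float() - ref.float()).abs())  # typically a visible nonzero diff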

If you want a "higher precision", you can try changing TCompute in this file to float32. While I dont think this will reduce the gap between the results from pytorch and TE, but it will certainly increase the precision of the operators from TE.
