Fix layernorm backward bug (#7164)
* fix layernorm backward index bug

* add layernorm test case

* auto format by CI

Co-authored-by: oneflow-ci-bot <[email protected]>
3 people authored Dec 31, 2021
1 parent 2eae496 commit de9fc41
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -107,13 +107,13 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   int64_t begin_norm_axis = ctx->begin_norm_axis;
   if (begin_norm_axis < 0) { begin_norm_axis += dy->shape()->NumAxes(); }
 
-  std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
   if (!ctx->has_affine) {
     // Use LayerNormParamGrad(Tensor dy, Tensor gamma, Int64 begin_params_axis, Double epsilon).
     dy = JUST(functional::LayerNormParamGrad(dy, begin_params_axis, ctx->epsilon));
   } else {
     // Use LayerNormAffineParamGrad(Tensor dy, Tensor gamma, Tensor normalized, Int64
     // begin_params_axis, Double epsilon).
+    std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
     std::shared_ptr<Tensor> normalized = saved_tensors.at(ctx->normalized_index);
     const auto& results = JUST(functional::LayerNormAffineParamGrad(
         dy, gamma, normalized, begin_params_axis, ctx->epsilon));
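The moved line is the whole fix: gamma was previously fetched from saved_tensors via ctx->gamma_index before checking ctx->has_affine, so the lookup also ran for layers that never saved affine parameters; the fix defers it to the affine branch. Below is a minimal, hypothetical sketch (not part of the commit) of the path this touches, assuming oneflow exposes a PyTorch-style nn.LayerNorm with an elementwise_affine flag:

import oneflow as flow

# Assumption: oneflow mirrors torch.nn.LayerNorm, including elementwise_affine.
# With elementwise_affine=False the layer saves no gamma/beta, so the backward
# pass goes through the has_affine == false branch of LayerNorm::Apply.
x = flow.randn(2, 4, 8, requires_grad=True)
m = flow.nn.LayerNorm(normalized_shape=(8,), elementwise_affine=False)

y = m(x)
y.sum().backward()   # exercises the non-affine backward path patched above
print(x.grad.shape)  # expected: (2, 4, 8)
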
18 changes: 18 additions & 0 deletions python/oneflow/test/modules/test_layernorm.py
@@ -203,6 +203,24 @@ def get_random_norm_shape():
         y = m(x)
         return y
 
+    @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3)
+    def test_layernorm_without_affine(test_case):
+        device = random_device()
+        channel = random(1, 200).to(int)
+        height = random(1, 2).to(int)
+        width = random(8192, 32768).to(int)
+
+        def get_random_norm_shape():
+            begin_axis = random(1, 3).to(int).value()
+            return tuple((channel.value(), height.value(), width.value())[begin_axis:])
+
+        m = torch.nn.LayerNorm(normalized_shape=get_random_norm_shape()).to(device)
+        x = random_pytorch_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to(
+            device
+        )
+        y = m(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
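The new case runs as a standard unittest module; assuming a built or installed oneflow, an invocation like the following (not part of the commit) would execute it:

python3 python/oneflow/test/modules/test_layernorm.py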
