[TUTORIALS] Removed #noformat in layer norm tutorial

Philippe Tillet
2022-05-12 12:41:25 -07:00
parent c736ba7c3e
commit 0835a4fb05


@@ -16,7 +16,6 @@ try:
 except ModuleNotFoundError:
     HAS_APEX = False
 
-# fmt: off
 
 @triton.jit
 def _layer_norm_fwd_fused(
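
For context, # fmt: off is a directive honored by Python formatters such as black (and recent yapf releases): everything between # fmt: off and a matching # fmt: on is left exactly as written. A minimal illustration (the constant below is hypothetical, not from the tutorial):

# fmt: off
IDENTITY = [
    [1, 0],
    [0, 1],
]
# fmt: on  # automatic formatting resumes here

With the directive removed, the kernels below this point are again subject to the project's automatic formatting, which is why the remaining hunks of this commit only insert blank lines.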
@@ -65,6 +64,8 @@ def _layer_norm_fwd_fused(
     tl.store(Out + cols, out, mask=mask)
 
+
 # Backward pass (DA + partial DW + partial DB)
+
 @triton.jit
 def _layer_norm_bwd_dx_fused(
     _DA,
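
The hunks only show the signatures of the fused kernels. As a rough sketch of what a fused layer-norm forward kernel of this shape looks like in Triton (a simplified stand-in, not the tutorial's actual _layer_norm_fwd_fused; the kernel name and the one-program-per-row assumption are illustrative):

import torch

import triton
import triton.language as tl


@triton.jit
def _layer_norm_fwd_sketch(X, Out, W, B, stride, N, eps,
                           BLOCK_SIZE: tl.constexpr):
    # one program instance normalizes one row of X
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(X + row * stride + cols, mask=mask, other=0.).to(tl.float32)
    # mean and variance of the row, computed within a single block
    mean = tl.sum(x, axis=0) / N
    xmean = tl.where(mask, x - mean, 0.)
    var = tl.sum(xmean * xmean, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # normalize, then apply the affine transform
    w = tl.load(W + cols, mask=mask)
    b = tl.load(B + cols, mask=mask)
    out = xmean * rstd * w + b
    tl.store(Out + row * stride + cols, out, mask=mask)

Launched with one program per row, e.g. _layer_norm_fwd_sketch[(x.shape[0],)](x, out, w, b, x.stride(0), x.shape[1], 1e-5, BLOCK_SIZE=1024). Fusing the mean/variance reduction with the normalization in one kernel is the point of the tutorial: each row stays on-chip instead of making several round trips to global memory.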
@@ -235,6 +236,7 @@ class LayerNorm(torch.autograd.Function):
 def layer_norm(a, normalized_shape, weight, bias, eps):
     return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
 
+
 def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
     torch.manual_seed(0)
     # create data
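
test_layer_norm continues past the end of this hunk. As a hedged sketch of how such a test typically checks the tutorial's layer_norm wrapper against PyTorch (the helper name, shapes, dtype, and tolerances below are illustrative assumptions, not the tutorial's exact values):

import torch

def _check_layer_norm(M=1151, N=8192, dtype=torch.float16, eps=1e-5, device='cuda'):
    torch.manual_seed(0)
    # create data (hypothetical shapes and values)
    a = torch.randn(M, N, dtype=dtype, device=device)
    weight = torch.rand(N, dtype=dtype, device=device)
    bias = torch.rand(N, dtype=dtype, device=device)
    # Triton wrapper (layer_norm, defined in the tutorial) vs. the PyTorch reference
    tri = layer_norm(a, (N,), weight, bias, eps)
    ref = torch.nn.functional.layer_norm(a, (N,), weight, bias, eps)
    torch.testing.assert_close(tri, ref, atol=1e-2, rtol=0)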