[TUTORIALS] Removed #noformat in layer norm tutorial
This commit is contained in:
@@ -16,7 +16,6 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
HAS_APEX = False
|
||||
|
||||
# fmt: off
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_fused(
|
||||
@@ -65,6 +64,8 @@ def _layer_norm_fwd_fused(
|
||||
tl.store(Out + cols, out, mask=mask)
|
||||
|
||||
# Backward pass (DA + partial DW + partial DB)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_dx_fused(
|
||||
_DA,
|
||||
@@ -235,6 +236,7 @@ class LayerNorm(torch.autograd.Function):
|
||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
|
||||
|
||||
|
||||
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||
torch.manual_seed(0)
|
||||
# create data
|
||||
|
Reference in New Issue
Block a user