[TUTORIALS] Removed #noformat in layer norm tutorial
This commit is contained in:
@@ -16,7 +16,6 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
HAS_APEX = False
|
HAS_APEX = False
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _layer_norm_fwd_fused(
|
def _layer_norm_fwd_fused(
|
||||||
@@ -65,6 +64,8 @@ def _layer_norm_fwd_fused(
|
|||||||
tl.store(Out + cols, out, mask=mask)
|
tl.store(Out + cols, out, mask=mask)
|
||||||
|
|
||||||
# Backward pass (DA + partial DW + partial DB)
|
# Backward pass (DA + partial DW + partial DB)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _layer_norm_bwd_dx_fused(
|
def _layer_norm_bwd_dx_fused(
|
||||||
_DA,
|
_DA,
|
||||||
@@ -235,6 +236,7 @@ class LayerNorm(torch.autograd.Function):
|
|||||||
def layer_norm(a, normalized_shape, weight, bias, eps):
|
def layer_norm(a, normalized_shape, weight, bias, eps):
|
||||||
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
|
return LayerNorm.apply(a, normalized_shape, weight, bias, eps)
|
||||||
|
|
||||||
|
|
||||||
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
# create data
|
# create data
|
||||||
|
Reference in New Issue
Block a user