diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 802c0aca7..9880b428f 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -16,14 +16,13 @@ try: except ModuleNotFoundError: HAS_APEX = False -# fmt: off @triton.jit def _layer_norm_fwd_fused( - Out, - A, - Weight, - Bias, + Out, + A, + Weight, + Bias, Mean, Rstd, stride, N, eps, BLOCK_SIZE: tl.constexpr, @@ -37,17 +36,17 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(A + cols, mask=cols