diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index a03fa2cf6..802c0aca7 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb( db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] + mask = (rows[:, None] < M) & (cols[None,:] < N) + offs = rows[:, None] * N + cols[None,:] a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) mean = tl.load(Mean + rows, mask=rows