From c736ba7c3e170b20c203bf7ef4616f931acab2f5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 12 May 2022 12:30:36 -0700 Subject: [PATCH] [TUTORIALS] Fixed formatting --- python/tutorials/05-layer-norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index a03fa2cf6..802c0aca7 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb( db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for i in range(0, M, BLOCK_SIZE_M): rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] + mask = (rows[:, None] < M) & (cols[None,:] < N) + offs = rows[:, None] * N + cols[None,:] a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32) dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32) mean = tl.load(Mean + rows, mask=rows