[TUTORIALS] Fixed formatting

This commit is contained in:
Philippe Tillet
2022-05-12 12:30:36 -07:00
parent cd30a99aa2
commit c736ba7c3e

View File

@@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb(
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
mask = (rows[:, None] < M) & (cols[None,:] < N)
offs = rows[:, None] * N + cols[None,:]
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
mean = tl.load(Mean + rows, mask=rows<M, other=0.)