[TUTORIALS] Fixed formatting
This commit is contained in:
@@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb(
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
mask = (rows[:, None] < M) & (cols[None,:] < N)
|
||||
offs = rows[:, None] * N + cols[None,:]
|
||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||
mean = tl.load(Mean + rows, mask=rows<M, other=0.)
|
||||
|
Reference in New Issue
Block a user