[TUTORIALS] Fixed formatting
This commit is contained in:
@@ -129,8 +129,8 @@ def _layer_norm_bwd_dwdb(
|
|||||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
for i in range(0, M, BLOCK_SIZE_M):
|
for i in range(0, M, BLOCK_SIZE_M):
|
||||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||||
mask = (rows[:, None] < M) & (cols[None, :] < N)
|
mask = (rows[:, None] < M) & (cols[None,:] < N)
|
||||||
offs = rows[:, None] * N + cols[None, :]
|
offs = rows[:, None] * N + cols[None,:]
|
||||||
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
a = tl.load(A + offs, mask=mask, other=0.).to(tl.float32)
|
||||||
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
dout = tl.load(DOut + offs, mask=mask, other=0.).to(tl.float32)
|
||||||
mean = tl.load(Mean + rows, mask=rows<M, other=0.)
|
mean = tl.load(Mean + rows, mask=rows<M, other=0.)
|
||||||
|
Reference in New Issue
Block a user