From bf32205edc00ecffc557884ae4699fae4ce91e63 Mon Sep 17 00:00:00 2001
From: Botao Yu
Date: Wed, 12 Jan 2022 03:07:16 +0800
Subject: [PATCH] [OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425)

---
 python/triton/ops/blocksparse/matmul.py  |  6 +++---
 python/triton/ops/blocksparse/softmax.py | 15 +++++++--------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index 9a04ded66..48efe7ea3 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -283,14 +283,14 @@ def dsd_lut(layout, block, step, trans, device):
     # -------------------------------
     # same as above, except that the increments are in the sparse memory layout
     if trans:
-        A_idx = torch.arange(num_blocks)
+        A_idx = torch.arange(num_blocks, device=layout.device)
     else:
         A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
         current_offset = 0
         for z in range(layout.size(0)):
-            layoutw = layout[z, :, :].clone()
+            layoutw = layout[z, :, :].clone().long()
             msum = layoutw.sum()
-            layoutw[layoutw > 0] = 1 + torch.arange(msum)
+            layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
             A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
             current_offset += msum
     A_incs = A_idx * block * block
diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py
index f9d49ae56..6ac76dcc4 100644
--- a/python/triton/ops/blocksparse/softmax.py
+++ b/python/triton/ops/blocksparse/softmax.py
@@ -106,19 +106,18 @@ def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexp
 class _softmax(torch.autograd.Function):
     @staticmethod
     def make_lut(layout, block, device):
-        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
-        sizes = _empty.clone()
         # sizes along rows
-        for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+        sizes = layout.sum(-1).view(-1)
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
         # block indices
-        idx = torch.arange(layout.sum())
-        head = layout.nonzero(as_tuple=False)[:, 0]
-        rows = layout.nonzero(as_tuple=False)[:, 1]
-        columns = layout.nonzero(as_tuple=False)[:, 2]
+        layout_sum = sizes.sum()
+        idx = torch.arange(layout_sum, device=layout.device)
+        layout_nonzero = layout.nonzero(as_tuple=False)
+        head = layout_nonzero[:, 0]
+        rows = layout_nonzero[:, 1]
+        columns = layout_nonzero[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
         offsets = offsets * 4 + 2 * sizes.numel()
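
Note (editor, not part of the patch): a minimal sketch of the two failure modes this patch fixes, using hypothetical toy tensors and assuming a CUDA device is available:

    import torch

    # Toy bool layout on CUDA; hypothetical data for illustration only.
    layout = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device="cuda")

    # Failure 1: torch.arange defaults to the CPU, so combining its result
    # with a CUDA layout raises a device-mismatch error. Passing
    # device=layout.device (as the patch does) keeps the whole LUT
    # construction on one device.
    idx = torch.arange(layout.sum(), device=layout.device)

    # Failure 2: writing integer block indices into a bool tensor collapses
    # every nonzero entry to True, destroying the indices. Casting the
    # working copy to int64 first (.clone().long(), as the patch does)
    # preserves them.
    layoutw = layout.clone().long()
    msum = layoutw.sum()
    layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
    print(layoutw)  # tensor([[1, 0], [0, 2]], device='cuda:0'), not all-True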