[OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425)

This commit is contained in:
Botao Yu
2022-01-12 03:07:16 +08:00
committed by GitHub
parent 94a2e10fe5
commit bf32205edc
2 changed files with 10 additions and 11 deletions

View File

@@ -283,14 +283,14 @@ def dsd_lut(layout, block, step, trans, device):
     # -------------------------------
     # same as above, except that the increments are in the sparse memory layout
     if trans:
-        A_idx = torch.arange(num_blocks)
+        A_idx = torch.arange(num_blocks, device=layout.device)
     else:
         A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
         current_offset = 0
         for z in range(layout.size(0)):
-            layoutw = layout[z, :, :].clone()
+            layoutw = layout[z, :, :].clone().long()
             msum = layoutw.sum()
-            layoutw[layoutw > 0] = 1 + torch.arange(msum)
+            layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
            A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
            current_offset += msum
    A_incs = A_idx * block * block

View File

@@ -106,19 +106,18 @@ def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, TMAX: tl.constexpr):
 class _softmax(torch.autograd.Function):
     @staticmethod
     def make_lut(layout, block, device):
-        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
-        sizes = _empty.clone()
         # sizes along rows
-        for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+        sizes = layout.sum(-1).view(-1)
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
         # block indices
-        idx = torch.arange(layout.sum())
-        head = layout.nonzero(as_tuple=False)[:, 0]
-        rows = layout.nonzero(as_tuple=False)[:, 1]
-        columns = layout.nonzero(as_tuple=False)[:, 2]
+        layout_sum = sizes.sum()
+        idx = torch.arange(layout_sum, device=layout.device)
+        layout_nonzero = layout.nonzero(as_tuple=False)
+        head = layout_nonzero[:, 0]
+        rows = layout_nonzero[:, 1]
+        columns = layout_nonzero[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
         offsets = offsets * 4 + 2 * sizes.numel()