[OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425)
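In summary: the blocksparse LUT helpers built index tensors with torch.arange(...) on the default CPU device and wrote integer block ranks into an uncast layout slice, both of which break when the layout is a CUDA bool tensor. _softmax.make_lut additionally computed row sizes in a per-head Python loop and called layout.nonzero three times. The hunks below create index tensors on layout.device, cast the per-head slice with .long(), and vectorize the rest.

A minimal sketch of the failure mode that the cast and the device placement address (hypothetical layout values, not from the commit; the device error only reproduces on a CUDA machine):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical one-head block-sparsity layout, stored as a bool mask.
layout = torch.tensor([[[1, 0], [1, 1]]], dtype=torch.bool, device=device)

# Pre-fix, clone() kept dtype=bool, so writing the int64 ranks 1..msum back
# collapsed them all to True, and torch.arange without device=... allocated
# on the CPU, which errors against a CUDA layout during the masked assignment.
layoutw = layout[0, :, :].clone().long()                              # the .long() fix
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)   # the device fix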
@@ -283,14 +283,14 @@ def dsd_lut(layout, block, step, trans, device):
     # -------------------------------
     # same as above, except that the increments are in the sparse memory layout
     if trans:
-        A_idx = torch.arange(num_blocks)
+        A_idx = torch.arange(num_blocks, device=layout.device)
     else:
         A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
         current_offset = 0
         for z in range(layout.size(0)):
-            layoutw = layout[z, :, :].clone()
+            layoutw = layout[z, :, :].clone().long()
             msum = layoutw.sum()
-            layoutw[layoutw > 0] = 1 + torch.arange(msum)
+            layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
             A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
             current_offset += msum
     A_incs = A_idx * block * block
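For reference, the loop body above ranks each head's nonzero blocks in row-major order, then reads the ranks back through the transpose, so A_idx lists the blocks of the transposed operand in sparse storage order. A small worked example (values chosen only for illustration):

import torch

layoutw = torch.tensor([[1, 0, 1],
                        [0, 1, 0]]).long()
msum = layoutw.sum()                             # 3 nonzero blocks
layoutw[layoutw > 0] = 1 + torch.arange(msum)    # row-major ranks 1, 2, 3
# layoutw is now [[1, 0, 2],
#                 [0, 3, 0]]
A_idx = layoutw.T[layoutw.T > 0] - 1             # read back column by column
print(A_idx)                                     # tensor([0, 2, 1])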
@@ -106,19 +106,18 @@ def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr
 class _softmax(torch.autograd.Function):
     @staticmethod
     def make_lut(layout, block, device):
         _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
-        sizes = _empty.clone()
-        # sizes along rows
-        for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+        sizes = layout.sum(-1).view(-1)
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
         # block indices
-        idx = torch.arange(layout.sum())
-        head = layout.nonzero(as_tuple=False)[:, 0]
-        rows = layout.nonzero(as_tuple=False)[:, 1]
-        columns = layout.nonzero(as_tuple=False)[:, 2]
+        layout_sum = sizes.sum()
+        idx = torch.arange(layout_sum, device=layout.device)
+        layout_nonzero = layout.nonzero(as_tuple=False)
+        head = layout_nonzero[:, 0]
+        rows = layout_nonzero[:, 1]
+        columns = layout_nonzero[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
         offsets = offsets * 4 + 2 * sizes.numel()
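The make_lut changes are pure refactors: the per-head loop collapses into a single reduction, and the three layout.nonzero calls become one call whose columns are sliced. A quick equivalence check for the loop removal (the layout shape is illustrative):

import torch

layout = torch.randint(0, 2, (4, 8, 8))  # (heads, rows, cols), values 0/1

# removed: per-head row sums concatenated in a Python loop
sizes_loop = torch.cat([layout[h, :, :].sum(-1) for h in range(layout.shape[0])])

# added: one reduction over the column dim, flattened head-major
sizes_vec = layout.sum(-1).view(-1)

assert torch.equal(sizes_loop, sizes_vec)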