[OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425)
@@ -283,14 +283,14 @@ def dsd_lut(layout, block, step, trans, device):
     # -------------------------------
     # same as above, except that the increments are in the sparse memory layout
     if trans:
-        A_idx = torch.arange(num_blocks)
+        A_idx = torch.arange(num_blocks, device=layout.device)
     else:
         A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
     current_offset = 0
     for z in range(layout.size(0)):
-        layoutw = layout[z, :, :].clone()
+        layoutw = layout[z, :, :].clone().long()
         msum = layoutw.sum()
-        layoutw[layoutw > 0] = 1 + torch.arange(msum)
+        layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
         A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
         current_offset += msum
     A_incs = A_idx * block * block
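The `.long()` cast and the two `device=layout.device` arguments in this hunk are what make a CUDA `torch.bool` layout work: `torch.arange` defaults to CPU, and integer block indices cannot be masked-assigned into a bool tensor. A minimal sketch of the fixed pattern (illustrative only, not part of the commit; assumes a CUDA device is available):

import torch

# a tiny bool block-sparsity mask on the GPU, shaped (heads, rows, cols)
layout = torch.tensor([[[1, 0], [0, 1]]], dtype=torch.bool, device="cuda")

z = 0
# bool -> int64 first: writing "1 + arange" into a bool tensor would fail
layoutw = layout[z, :, :].clone().long()
msum = layoutw.sum()
# device=layout.device keeps the arange on the GPU; a CPU arange cannot be
# masked-assigned into (or later concatenated with) a CUDA tensor
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)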
@@ -106,19 +106,18 @@ def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexp
 class _softmax(torch.autograd.Function):
     @staticmethod
     def make_lut(layout, block, device):
-        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
-        sizes = _empty.clone()
         # sizes along rows
-        for h in range(layout.shape[0]):
-            sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+        sizes = layout.sum(-1).view(-1)
         # offsets in block format
         offsets = torch.zeros_like(sizes)
         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
         # block indices
-        idx = torch.arange(layout.sum())
-        head = layout.nonzero(as_tuple=False)[:, 0]
-        rows = layout.nonzero(as_tuple=False)[:, 1]
-        columns = layout.nonzero(as_tuple=False)[:, 2]
+        layout_sum = sizes.sum()
+        idx = torch.arange(layout_sum, device=layout.device)
+        layout_nonzero = layout.nonzero(as_tuple=False)
+        head = layout_nonzero[:, 0]
+        rows = layout_nonzero[:, 1]
+        columns = layout_nonzero[:, 2]
         core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
         # construct look-up table
         offsets = offsets * 4 + 2 * sizes.numel()
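This hunk drops the per-head loop that built `sizes` by repeated `torch.cat`: one reduction over the last dimension yields the same flattened tensor, and `layout.nonzero` is computed once and its columns reused instead of being evaluated three times. A quick equivalence check (illustrative only, not part of the commit; CPU is fine here):

import torch

layout = (torch.rand(3, 4, 4) > 0.5).long()   # (heads, rows, cols) block mask

# old: concatenate row sums head by head
sizes_loop = torch.tensor([], dtype=torch.int64)
for h in range(layout.shape[0]):
    sizes_loop = torch.cat((sizes_loop, layout[h, :, :].sum(-1)))

# new: one reduction over the last dim, flattened
sizes_vec = layout.sum(-1).view(-1)

assert torch.equal(sizes_loop, sizes_vec)

# nonzero computed once, columns sliced out instead of three nonzero() calls
nz = layout.nonzero(as_tuple=False)
head, rows, columns = nz[:, 0], nz[:, 1], nz[:, 2]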