[OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445)

This commit is contained in:
Philippe Tillet
2022-01-28 21:40:30 -08:00
committed by GitHub
parent e68d6a7776
commit bd52e530a0

View File

@@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device):
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# pad by a factor 2*MAX_NUM_STAGES
# to accomodate pre-fetching inside the kernel
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
incs = torch.cat((incs, pad))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)