[OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445)
This commit is contained in:
@@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device):
|
||||
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
|
||||
# create increments
|
||||
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
|
||||
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
|
||||
# pad by a factor 2*MAX_NUM_STAGES
|
||||
# to accomodate pre-fetching inside the kernel
|
||||
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
|
||||
incs = torch.cat((incs, pad))
|
||||
# create lut
|
||||
lut = torch.cat((header, incs))
|
||||
lut = lut.type(torch.int32).to(device)
|
||||
|
Reference in New Issue
Block a user