[OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445)
This commit is contained in:
@@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device):
|
|||||||
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
|
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
|
||||||
# create increments
|
# create increments
|
||||||
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
|
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
|
||||||
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
|
# pad by a factor 2*MAX_NUM_STAGES
|
||||||
|
# to accomodate pre-fetching inside the kernel
|
||||||
|
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
|
||||||
|
incs = torch.cat((incs, pad))
|
||||||
# create lut
|
# create lut
|
||||||
lut = torch.cat((header, incs))
|
lut = torch.cat((header, incs))
|
||||||
lut = lut.type(torch.int32).to(device)
|
lut = lut.type(torch.int32).to(device)
|
||||||
|
Reference in New Issue
Block a user