diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 48efe7ea3..ebcea8e1e 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device): header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() # create increments incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + # pad by a factor 2*MAX_NUM_STAGES + # to accomodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) # create lut lut = torch.cat((header, incs)) lut = lut.type(torch.int32).to(device)