From bd52e530a027f896ac8f400df3a8894f01a2e265 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Fri, 28 Jan 2022 21:40:30 -0800
Subject: [PATCH] [OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445)

---
 python/triton/ops/blocksparse/matmul.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py
index 48efe7ea3..ebcea8e1e 100644
--- a/python/triton/ops/blocksparse/matmul.py
+++ b/python/triton/ops/blocksparse/matmul.py
@@ -311,7 +311,10 @@ def dsd_lut(layout, block, step, trans, device):
     header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
     # create increments
     incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
-    incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
+    # pad with 2*MAX_NUM_STAGES zeros
+    # to accommodate pre-fetching inside the kernel
+    pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
+    incs = torch.cat((incs, pad))
     # create lut
     lut = torch.cat((header, incs))
     lut = lut.type(torch.int32).to(device)
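
Note: the sketch below shows the padding logic from the hunk in isolation, for readers who want to run it standalone. MAX_NUM_STAGES = 10 is an assumption inferred from the hard-coded pad of 20 (= 2*MAX_NUM_STAGES) entries, and pad_increments is a hypothetical helper written for illustration, not part of Triton's API.

    import torch

    # Assumed value: the patch pads with 20 zeros and describes this as
    # 2*MAX_NUM_STAGES, which implies the pipeliner pre-fetches at most
    # 10 stages ahead. Constant and helper name are illustrative only.
    MAX_NUM_STAGES = 10

    def pad_increments(incs):
        # The kernel consumes `incs` two entries per iteration (one B
        # increment, one A increment) and, with software pipelining, may
        # load several iterations ahead. Appending 2*MAX_NUM_STAGES zeros
        # keeps those speculative loads in bounds, and the zero increments
        # leave the pre-fetched pointers unchanged.
        pad = torch.zeros(2 * MAX_NUM_STAGES, device=incs.device, dtype=incs.dtype)
        return torch.cat((incs, pad))

    # Interleave (B_inc, A_inc) pairs the same way dsd_lut does, then pad.
    B_incs = torch.tensor([16, 16, 16])
    A_incs = torch.tensor([4, 4, 4])
    incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
    padded = pad_increments(incs)
    assert padded.numel() == incs.numel() + 2 * MAX_NUM_STAGES

The previous pad of only 2 zeros covered a single look-ahead pair; presumably, with deeper pipelining the pre-fetcher could read past the end of the LUT, which is the padding issue this patch addresses.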