[PYTHON] Bugfix on FP32 blocksparse matmul

Philippe Tillet
2021-02-24 14:43:08 -05:00
parent 567a1a3d17
commit f64b779b0d

@@ -362,7 +362,7 @@ class matmul:
             return self.lut_cache[key]
         # C look-up table
         layout, block = self.layout, self.block
-        step = 16
+        step = 8 if dtype == torch.float32 else 16
         if self.mode == 'sdd':
             c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
         elif self.mode == 'dsd':
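
The fix makes the look-up-table step depend on the operand dtype instead of hard-coding 16: FP32 inputs now use a step of 8, while half-precision inputs keep the previous step of 16. A minimal sketch of that selection logic, pulled out for illustration (the helper name lut_step is hypothetical, not part of the repo):

import torch

def lut_step(dtype: torch.dtype) -> int:
    # Mirror the fixed line in make_lut: FP32 uses a step of 8,
    # every other dtype (e.g. FP16) keeps the step of 16.
    return 8 if dtype == torch.float32 else 16

assert lut_step(torch.float32) == 8
assert lut_step(torch.float16) == 16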