diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index f3a1348e0..c7fcfd2a2 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -57,7 +57,7 @@ class _matmul(torch.autograd.Function): 'TN' : _matmul.TN, 'TK' : _matmul.TK, 'TZ' : _matmul.TZ, - 'IS_TK_DIV_K' : is_tk_div_k + 'IS_TK_DIV_K' : int(is_tk_div_k) } _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines) kernel = _matmul._kernels[key]