[PYTHON] Fixed issue with IS_TK_DIV_K

This commit is contained in:
Philippe Tillet
2021-01-31 13:36:19 -05:00
parent 7cf358a352
commit 52af8cda34

@@ -57,7 +57,7 @@ class _matmul(torch.autograd.Function):
 'TN' : _matmul.TN,
 'TK' : _matmul.TK,
 'TZ' : _matmul.TZ,
-'IS_TK_DIV_K' : is_tk_div_k
+'IS_TK_DIV_K' : int(is_tk_div_k)
 }
 _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
 kernel = _matmul._kernels[key]
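
Note on why the cast likely matters (a minimal sketch, not part of this commit): assuming the values in the defines dict are stringified into the generated kernel source as preprocessor-style macros, a Python bool would surface as "True"/"False", which C-like kernel code cannot evaluate, whereas int(...) yields "1"/"0".

# Minimal sketch, not the Triton API: shows how a bool vs. an int value
# would look once stringified into preprocessor-style defines.
# (The macro-substitution behavior is an assumption, not shown in this diff.)
is_tk_div_k = True

defines_bool = {'IS_TK_DIV_K': is_tk_div_k}        # value before this commit
defines_int  = {'IS_TK_DIV_K': int(is_tk_div_k)}   # value after this commit

print({k: str(v) for k, v in defines_bool.items()})  # {'IS_TK_DIV_K': 'True'}
print({k: str(v) for k, v in defines_int.items()})   # {'IS_TK_DIV_K': '1'}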