diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index 9c10b88d8..004f236b9 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): def get_tflops(backend, device, num_ctas, num_warps, dtype): cc = _triton.runtime.cc(backend, device) if cc < 80 and dtype == torch.float32: - return get_simd_tflops() + return get_simd_tflops(backend, device, num_ctas, num_warps, dtype) return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)