[FRONTEND] Add missing args to get_simd_tflops() (#578)

This commit is contained in:
Jason Ansel
2022-07-11 14:37:59 -07:00
committed by GitHub
parent 4a399a7e40
commit c9a2b9c7d4

View File

@@ -26,7 +26,7 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
def get_tflops(backend, device, num_ctas, num_warps, dtype):
cc = _triton.runtime.cc(backend, device)
if cc < 80 and dtype == torch.float32:
- return get_simd_tflops()
+ return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)