diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index b776b3dbf..30864f391 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -5,7 +5,7 @@ import triton def rounded_linspace(low, high, steps, div): ret = torch.linspace(low, high, steps) - ret = (ret.int() + div - 1) // div * div + ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div ret = torch.unique(ret) return list(map(int, ret))