[GENERAL] Minor improvements: (#110)

* Load libcuda.so.1 if libcuda.so is not there. Error if both aren't
there.
* Support for multiple grad_to_none in triton.testing.do_bench
* Benchmark dataframe printed along with name
This commit is contained in:
Philippe Tillet
2021-05-17 19:16:11 -04:00
committed by Philippe Tillet
parent 288b4f7f58
commit 9f30af76fb
4 changed files with 12 additions and 5 deletions

View File

@@ -32,9 +32,9 @@ def bench_op(M, N, dtype, mode, provider):
y = op(x, idx)
dy = torch.randn_like(y)
fn = lambda: y.backward(dy, retain_graph=True)
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
if __name__ == '__main__':
bench_op.run('tmp', False)
bench_op.run(print_data=True)