[GENERAL] Minor improvements: (#110)
* Load libcuda.so.1 if libcuda.so is not there. Error if both aren't there. * Support for multiple grad_to_none in triton.testing.do_bench * Benchmark dataframe printed along with name
This commit is contained in:
committed by
Philippe Tillet
parent
288b4f7f58
commit
9f30af76fb
@@ -32,9 +32,9 @@ def bench_op(M, N, dtype, mode, provider):
|
||||
y = op(x, idx)
|
||||
dy = torch.randn_like(y)
|
||||
fn = lambda: y.backward(dy, retain_graph=True)
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
|
||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
|
||||
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
bench_op.run('tmp', False)
|
||||
bench_op.run(print_data=True)
|
Reference in New Issue
Block a user