[GENERAL] Minor improvements: (#110)

* Load libcuda.so.1 if libcuda.so is not found; error out if neither is there.
* Support for passing multiple tensors as grad_to_none to triton.testing.do_bench (usage sketched below).
* Benchmark dataframe is printed along with the benchmark's name when print_data=True.
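
A minimal usage sketch of the new grad_to_none behaviour, assuming a simple PyTorch workload (the tensors, shapes, and names below are illustrative only, not part of this commit):

    import torch
    import triton

    # two leaf tensors whose .grad should be reset before every timed run
    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    w = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    y = (x @ w).sum()
    fn = lambda: y.backward(retain_graph=True)

    # grad_to_none now accepts a list of tensors instead of a single tensor
    mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x, w])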
Philippe Tillet
2021-05-17 19:16:11 -04:00
committed by Philippe Tillet
parent 288b4f7f58
commit 9f30af76fb
4 changed files with 12 additions and 5 deletions


@@ -95,8 +95,13 @@ bool dispatch::cuinit(){
   if(cuda_==nullptr){
     putenv((char*)"CUDA_CACHE_DISABLE=1");
     std::string libcuda = tools::getenv("TRITON_LIBCUDA");
-    if(libcuda.empty())
+    if(libcuda.empty()){
       cuda_ = dlopen("libcuda.so", RTLD_LAZY);
+      if(!cuda_)
+        cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
+      if(!cuda_)
+        throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
+    }
     else
       cuda_ = dlopen(libcuda.c_str(), RTLD_LAZY);
   }
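
For illustration only (this is not part of the commit), the same lookup order can be expressed in a few lines of Python with ctypes; the helper name below is made up:

    import ctypes

    def load_libcuda():
        # mirror the dlopen fallback above: try the unversioned name first,
        # then the versioned SONAME, and fail loudly if neither can be loaded
        for name in ("libcuda.so", "libcuda.so.1"):
            try:
                return ctypes.CDLL(name)
            except OSError:
                continue
        raise RuntimeError("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.")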


@@ -32,9 +32,9 @@ def bench_op(M, N, dtype, mode, provider):
     y = op(x, idx)
     dy = torch.randn_like(y)
     fn = lambda: y.backward(dy, retain_graph=True)
-    mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
+    mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
     return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
 if __name__ == '__main__':
-    bench_op.run('tmp', False)
+    bench_op.run(print_data=True)


@@ -51,7 +51,7 @@ def builtin(fn):
     def wrapper(*args, **kwargs):
         if 'builder' not in kwargs or \
            kwargs['builder'] is None:
-            raise ValueError("Builder argument must be provided outside of JIT functions")
+            raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?")
         return fn(*args, **kwargs)
     if wrapper.__doc__:


@@ -123,7 +123,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
         # if it contains a backward pass. So we clear the
         # provided gradients
         if grad_to_none is not None:
-            grad_to_none.grad = None
+            for x in grad_to_none:
+                x.grad = None
         # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
@@ -246,6 +247,7 @@ class Mark:
             plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
         df = df[[bench.x_names[0]] + bench.line_names]
         if print_data:
+            print(bench.plot_name + ':')
             print(df)
         if save_path:
             df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
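
Rough sketch of how a benchmark now reports itself when run with print_data=True; the perf_report/Benchmark decorator arguments are assumed from the surrounding Triton API, and the benchmark below is made up for illustration:

    import torch
    import triton

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],                 # first column of the printed dataframe
            x_vals=[1024, 2048, 4096],
            line_arg='provider',
            line_vals=['torch'],
            line_names=['Torch'],          # remaining columns of the printed dataframe
            ylabel='GB/s',
            plot_name='copy-performance',  # printed as "copy-performance:" above the dataframe
            args={},
        )
    )
    def bench_copy(N, provider):
        x = torch.randn(N, device='cuda')
        # a copy reads and writes N elements, hence 2 * numel * element_size bytes
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: x.clone())
        return gbps(mean_ms), gbps(min_ms), gbps(max_ms)

    bench_copy.run(print_data=True)  # prints "copy-performance:" followed by the dataframe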