[GENERAL] Minor improvements: (#110)
* Load `libcuda.so.1` if `libcuda.so` cannot be found; raise an error if neither is available.
* Support multiple tensors in the `grad_to_none` argument of `triton.testing.do_bench`.
* Print the benchmark name along with its dataframe.
This commit is contained in:
committed by
Philippe Tillet
parent
288b4f7f58
commit
9f30af76fb
@@ -95,8 +95,13 @@ bool dispatch::cuinit(){
|
|||||||
if(cuda_==nullptr){
|
if(cuda_==nullptr){
|
||||||
putenv((char*)"CUDA_CACHE_DISABLE=1");
|
putenv((char*)"CUDA_CACHE_DISABLE=1");
|
||||||
std::string libcuda = tools::getenv("TRITON_LIBCUDA");
|
std::string libcuda = tools::getenv("TRITON_LIBCUDA");
|
||||||
if(libcuda.empty())
|
if(libcuda.empty()){
|
||||||
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
cuda_ = dlopen("libcuda.so", RTLD_LAZY);
|
||||||
|
if(!cuda_)
|
||||||
|
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
|
||||||
|
if(!cuda_)
|
||||||
|
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
|
||||||
|
}
|
||||||
else
|
else
|
||||||
cuda_ = dlopen(libcuda.c_str(), RTLD_LAZY);
|
cuda_ = dlopen(libcuda.c_str(), RTLD_LAZY);
|
||||||
}
|
}
|
||||||
|
@@ -32,9 +32,9 @@ def bench_op(M, N, dtype, mode, provider):
|
|||||||
y = op(x, idx)
|
y = op(x, idx)
|
||||||
dy = torch.randn_like(y)
|
dy = torch.randn_like(y)
|
||||||
fn = lambda: y.backward(dy, retain_graph=True)
|
fn = lambda: y.backward(dy, retain_graph=True)
|
||||||
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
|
mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
|
||||||
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
bench_op.run('tmp', False)
|
bench_op.run(print_data=True)
|
@@ -51,7 +51,7 @@ def builtin(fn):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if 'builder' not in kwargs or \
|
if 'builder' not in kwargs or \
|
||||||
kwargs['builder'] is None:
|
kwargs['builder'] is None:
|
||||||
raise ValueError("Builder argument must be provided outside of JIT functions")
|
raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?")
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
if wrapper.__doc__:
|
if wrapper.__doc__:
|
||||||
|
@@ -123,7 +123,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
|
|||||||
# if it contains a backward pass. So we clear the
|
# if it contains a backward pass. So we clear the
|
||||||
# provided gradients
|
# provided gradients
|
||||||
if grad_to_none is not None:
|
if grad_to_none is not None:
|
||||||
grad_to_none.grad = None
|
for x in grad_to_none:
|
||||||
|
x.grad = None
|
||||||
# we clear the L2 cache before each run
|
# we clear the L2 cache before each run
|
||||||
cache.zero_()
|
cache.zero_()
|
||||||
# record time of `fn`
|
# record time of `fn`
|
||||||
@@ -246,6 +247,7 @@ class Mark:
|
|||||||
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
|
||||||
df = df[[bench.x_names[0]] + bench.line_names]
|
df = df[[bench.x_names[0]] + bench.line_names]
|
||||||
if print_data:
|
if print_data:
|
||||||
|
print(bench.plot_name + ':')
|
||||||
print(df)
|
print(df)
|
||||||
if save_path:
|
if save_path:
|
||||||
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
|
||||||
|
Reference in New Issue
Block a user