[GENERAL] Minor improvements: (#110)
* Load libcuda.so.1 if libcuda.so is not there; error out if neither is present.
* Support for multiple grad_to_none tensors in triton.testing.do_bench
* Benchmark dataframe printed along with its name
committed by Philippe Tillet

parent 288b4f7f58
commit 9f30af76fb
@@ -95,8 +95,13 @@ bool dispatch::cuinit(){
   if(cuda_==nullptr){
     putenv((char*)"CUDA_CACHE_DISABLE=1");
     std::string libcuda = tools::getenv("TRITON_LIBCUDA");
-    if(libcuda.empty())
+    if(libcuda.empty()){
       cuda_ = dlopen("libcuda.so", RTLD_LAZY);
+      if(!cuda_)
+        cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
+      if(!cuda_)
+        throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
+    }
     else
       cuda_ = dlopen(libcuda.c_str(), RTLD_LAZY);
   }
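The search order above — explicit TRITON_LIBCUDA override, then libcuda.so, then the libcuda.so.1 SONAME, then a hard error — can be sketched in Python with ctypes (illustrative only; dispatch.cc calls dlopen directly, and load_libcuda is not a real Triton function):

    import ctypes
    import os

    def load_libcuda():
        # Same search order as dispatch::cuinit: honor the TRITON_LIBCUDA
        # override if set, otherwise try libcuda.so, then libcuda.so.1.
        override = os.environ.get("TRITON_LIBCUDA", "")
        candidates = [override] if override else ["libcuda.so", "libcuda.so.1"]
        for name in candidates:
            try:
                return ctypes.CDLL(name)  # ctypes.CDLL wraps dlopen()
            except OSError:
                continue
        raise RuntimeError("Could not find `libcuda.so`. "
                           "Make sure it is in your LD_LIBRARY_PATH.")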
@@ -32,9 +32,9 @@ def bench_op(M, N, dtype, mode, provider):
         y = op(x, idx)
         dy = torch.randn_like(y)
         fn = lambda: y.backward(dy, retain_graph=True)
-        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=x)
+        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
     return gbps(mean_ms), gbps(min_ms), gbps(max_ms)


 if __name__ == '__main__':
-    bench_op.run('tmp', False)
+    bench_op.run(print_data=True)
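With grad_to_none now taking a list, a benchmark whose backward pass writes to several inputs can reset all of them between reps. A minimal sketch (tensor names are illustrative):

    import torch
    import triton

    a = torch.randn(1024, device='cuda', requires_grad=True)
    b = torch.randn(1024, device='cuda', requires_grad=True)
    y = (a * b).sum()
    fn = lambda: y.backward(retain_graph=True)
    # Both a.grad and b.grad are reset before each timed run,
    # so gradient accumulation never distorts the measurement.
    mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[a, b])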
@@ -51,7 +51,7 @@ def builtin(fn):
     def wrapper(*args, **kwargs):
         if 'builder' not in kwargs or \
            kwargs['builder'] is None:
-            raise ValueError("Builder argument must be provided outside of JIT functions")
+            raise ValueError("Builder argument must be provided outside of JIT functions. Did you forget to add @triton.jit ?")
         return fn(*args, **kwargs)

     if wrapper.__doc__:
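For context, the guard this hunk rewords follows the usual decorator pattern; a self-contained sketch (simplified, not Triton's actual internals):

    import functools

    def builtin(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The JIT machinery injects `builder`; a plain Python call won't.
            if kwargs.get('builder') is None:
                raise ValueError("Builder argument must be provided outside of JIT functions. "
                                 "Did you forget to add @triton.jit ?")
            return fn(*args, **kwargs)
        return wrapper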
@@ -123,7 +123,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):
         # if it contains a backward pass. So we clear the
         # provided gradients
         if grad_to_none is not None:
-            grad_to_none.grad = None
+            for x in grad_to_none:
+                x.grad = None
         # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
@@ -246,6 +247,7 @@ class Mark:
                 plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
         df = df[[bench.x_names[0]] + bench.line_names]
         if print_data:
+            print(bench.plot_name + ':')
             print(df)
         if save_path:
             df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
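With print_data=True, each benchmark now prints its plot_name as a header before the dataframe, so several tables in one run stay distinguishable. Roughly (column names and values are made up):

    import pandas as pd

    df = pd.DataFrame({'N': [1024, 2048], 'Triton': [512.0, 640.0], 'Torch': [480.0, 600.0]})
    plot_name = 'my-benchmark'    # stands in for bench.plot_name
    print(plot_name + ':')        # the header line added by this commit
    print(df)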