* Load libcuda.so.1 if libcuda.so is not there. Error if both aren't there. * Support for multiple grad_to_none in triton.testing.do_bench * Benchmark dataframe printed along with name
		
			
				
	
	
		
			40 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import triton
 | 
						|
 | 
						|
confs = [
 | 
						|
    triton.testing.Benchmark(
 | 
						|
              x_names = ['N'],
 | 
						|
              x_vals  = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
 | 
						|
              line_arg  = 'provider',
 | 
						|
              line_vals  = ['triton', 'torch'],
 | 
						|
              line_names = ['Triton', 'Torch'],
 | 
						|
              ylabel  = 'GBPS',
 | 
						|
              plot_name = f'{mode}-2048',
 | 
						|
              args = {'M': 2048, 'dtype': torch.float16, 'mode': mode}
 | 
						|
    )\
 | 
						|
    for mode in ['forward', 'backward']
 | 
						|
]
 | 
						|
 | 
						|
 | 
						|
@triton.testing.perf_report(confs)
 | 
						|
def bench_op(M, N, dtype, mode, provider):
 | 
						|
    # create inputs
 | 
						|
    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
 | 
						|
    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
 | 
						|
    num_gb = (2 * x.numel() * x.element_size() * 1e-9)
 | 
						|
    gbps = lambda ms: num_gb / ms * 1e3
 | 
						|
    # forward pass
 | 
						|
    op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \
 | 
						|
         'triton': triton.ops.cross_entropy}[provider]
 | 
						|
    if mode == 'forward':
 | 
						|
        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
 | 
						|
    if mode == 'backward':
 | 
						|
        y = op(x, idx)
 | 
						|
        dy = torch.randn_like(y)
 | 
						|
        fn = lambda: y.backward(dy, retain_graph=True)
 | 
						|
        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
 | 
						|
    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
 | 
						|
    bench_op.run(print_data=True) |