2021-02-08 12:16:41 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import torch
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import triton
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								# Benchmark configurations: one per pass direction. Each one sweeps the
								# column count N while keeping M fixed at 2048 rows of float16 data, and
								# draws one line per provider ('triton' and 'torch').
								confs = [
								    triton.testing.Benchmark(
								        x_names=['N'],
								        x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
								        line_arg='provider',
								        line_vals=['triton', 'torch'],
								        line_names=['Triton', 'Torch'],
								        ylabel='GBPS',
								        plot_name=f'{mode}-2048',
								        # fixed arguments forwarded to the benchmarked function
								        args={'M': 2048, 'dtype': torch.float16, 'mode': mode},
								    )
								    for mode in ('forward', 'backward')
								]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2021-03-08 20:19:10 -05:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2021-02-08 12:16:41 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								@triton.testing.perf_report(confs)
								def bench_op(M, N, dtype, mode, provider):
								    """Benchmark cross-entropy for one provider and report throughput in GB/s.

								    Args:
								        M: number of rows of the random input matrix.
								        N: number of columns of the random input matrix. Every row uses
								            index 5 as its second operand, so N presumably has to be
								            greater than 5 -- TODO confirm against the op's contract.
								        dtype: torch dtype of the input matrix.
								        mode: 'forward' times the op call itself; 'backward' times
								            ``backward`` on an output computed once up front.
								        provider: 'torch' or 'triton'; selects the implementation.

								    Returns:
								        A 3-tuple of throughputs in GB/s, one for each of the three timings
								        returned by ``triton.testing.do_bench`` (named ``mean_ms``,
								        ``min_ms`` and ``max_ms`` here), in that order. Throughput is
								        inversely proportional to time, so the entry derived from
								        ``min_ms`` is the highest throughput.

								    Raises:
								        KeyError: if ``provider`` is not 'torch' or 'triton'.
								        ValueError: if ``mode`` is not 'forward' or 'backward'.
								    """
								    # create inputs
								    x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
								    # 4 + ones -> every entry equals 5
								    idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
								    # size of x in GB, counted twice (presumably one read plus one write
								    # of the matrix -- TODO confirm the intended traffic model)
								    num_gb = 2 * x.numel() * x.element_size() * 1e-9

								    def gbps(ms):
								        # convert a duration in milliseconds into GB per second
								        return num_gb / ms * 1e3

								    # select the implementation under test
								    op = {
								        'torch': torch.nn.CrossEntropyLoss(reduction='none'),
								        'triton': triton.ops.cross_entropy,
								    }[provider]

								    if mode == 'forward':
								        # forward pass
								        mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))
								    elif mode == 'backward':
								        # backward pass: run the forward once, then time only backward()
								        y = op(x, idx)
								        dy = torch.randn_like(y)

								        def fn():
								            # retain_graph=True so the same graph can be replayed
								            # on every timed iteration
								            return y.backward(dy, retain_graph=True)

								        mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x])
								    else:
								        # previously fell through to the return and failed with an
								        # UnboundLocalError on mean_ms; fail with a clear message instead
								        raise ValueError(
								            f"unsupported mode {mode!r}; expected 'forward' or 'backward'"
								        )
								    return gbps(mean_ms), gbps(min_ms), gbps(max_ms)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2021-02-08 12:16:41 -08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								# Script entry point: run every benchmark configuration attached to
								# bench_op, passing print_data=True to the runner.
								if __name__ == '__main__':
								    bench_op.run(print_data=True)
							 |