From 268d2cd18d3168f423803f1e54097c0a51fd52de Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Wed, 4 Jan 2023 17:08:08 -0800
Subject: [PATCH] better convert + write-back

---
 python/bwd.ttgir                       | 8 ++++----
 python/tutorials/06-fused-attention.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/bwd.ttgir b/python/bwd.ttgir
index 5ac495943..d3f9aeccc 100644
--- a/python/bwd.ttgir
+++ b/python/bwd.ttgir
@@ -154,13 +154,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
         %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
         scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
       }
-      %80 = triton_gpu.convert_layout %79#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
-      %81 = triton_gpu.convert_layout %79#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
       %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %83 = arith.truncf %81 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+      %81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+      %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
       tt.store %82, %83 : tensor<128x64xf16, #blocked1>
       %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %85 = arith.truncf %80 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+      %80 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+      %85 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
       tt.store %84, %85 : tensor<128x64xf16, #blocked1>
     }
     return

diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index f29e1eeaf..2da77fbc6 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -331,7 +331,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 15)],
+    x_vals=[2**i for i in range(10, 13)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
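
The bwd.ttgir hunk hoists the arith.truncf to f16 ahead of the triton_gpu.convert_layout, so the #mma1-to-#blocked1 conversion feeding each tt.store now shuffles f16 rather than f32 data; the stores themselves were already f16. The tutorial hunk only shrinks the benchmark sweep; a minimal standalone sketch (not taken from the patch) of the values the old and new x_vals expressions enumerate:

    # Sequence lengths swept by the benchmark before and after the x_vals change
    old_x_vals = [2**i for i in range(10, 15)]  # [1024, 2048, 4096, 8192, 16384]
    new_x_vals = [2**i for i in range(10, 13)]  # [1024, 2048, 4096]
    print(old_x_vals)
    print(new_x_vals)

The benchmark therefore stops at N_CTX = 4096 instead of 16384; the surrounding Benchmark fields are unchanged.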