better convert + write-back
This commit is contained in:
@@ -154,13 +154,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
|
||||
%139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||
scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
|
||||
}
|
||||
%80 = triton_gpu.convert_layout %79#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
|
||||
%81 = triton_gpu.convert_layout %79#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
|
||||
%82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||
%83 = arith.truncf %81 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
|
||||
%81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
||||
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
||||
tt.store %82, %83 : tensor<128x64xf16, #blocked1>
|
||||
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||
%85 = arith.truncf %80 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
|
||||
%80 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
|
||||
%85 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
|
||||
tt.store %84, %85 : tensor<128x64xf16, #blocked1>
|
||||
}
|
||||
return
|
||||
|
@@ -331,7 +331,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
x_vals=[2**i for i in range(10, 13)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'],
|
||||
line_names=['Triton'],
|
||||
|
Reference in New Issue
Block a user