better convert + write-back
@@ -154,13 +154,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
       %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       scf.yield %113, %130, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
     }
-    %80 = triton_gpu.convert_layout %79#1 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
-    %81 = triton_gpu.convert_layout %79#0 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked1>
     %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-    %83 = arith.truncf %81 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+    %81 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+    %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
     tt.store %82, %83 : tensor<128x64xf16, #blocked1>
     %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-    %85 = arith.truncf %80 : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+    %80 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
+    %85 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
     tt.store %84, %85 : tensor<128x64xf16, #blocked1>
   }
   return
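This hunk reorders the write-back epilogue: instead of converting the f32 accumulators %79#0 and %79#1 out of the #mma1 layout and truncating afterwards, it truncates to f16 while still in #mma1 and only then converts to #blocked1 for the store. Since triton_gpu.convert_layout shuffles data between threads (typically through shared memory), converting f16 instead of f32 moves half the bytes. Presumably this is the "better convert" in the commit title. A toy NumPy model of that cost argument (not Triton internals; bytes_moved is an illustrative proxy):

import numpy as np

# stands in for one 128x64 f32 accumulator produced by the loop (%79#0)
acc = np.zeros((128, 64), dtype=np.float32)

def bytes_moved(t: np.ndarray) -> int:
    # proxy for the cost of triton_gpu.convert_layout: every element
    # crosses shared memory once, so cost scales with total bytes
    return t.nbytes

old_cost = bytes_moved(acc)                     # convert f32, then truncf: 32768 bytes
new_cost = bytes_moved(acc.astype(np.float16))  # truncf first, then convert: 16384 bytes
assert new_cost * 2 == old_cost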
@@ -331,7 +331,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 15)],
+    x_vals=[2**i for i in range(10, 13)],
     line_arg='provider',
     line_vals=['triton'],
     line_names=['Triton'],
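This hunk narrows the benchmark sweep from N_CTX in {2^10 .. 2^14} (1024-16384) to {2^10 .. 2^12} (1024-4096). For context, a minimal sketch of how a triton.testing.Benchmark config like this is consumed; styles, ylabel, plot_name, and run_attention() below are hypothetical stand-ins, not code from this file:

import triton
import triton.testing

configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 13)],  # 1024, 2048, 4096
    line_arg='provider',
    line_vals=['triton'],
    line_names=['Triton'],
    styles=[('red', '-')],        # assumed; not shown in this hunk
    ylabel='ms',                  # assumed
    plot_name='fused-attention',  # assumed
    args={},
)]

@triton.testing.perf_report(configs)
def bench(N_CTX, provider):
    # run_attention() is a hypothetical stand-in for the benchmarked kernel
    return triton.testing.do_bench(lambda: run_attention(N_CTX))

bench.run(print_data=True)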