[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)

Author: Qingyi Liu
Date: 2022-11-15 09:27:32 +08:00
Committed by: GitHub
Parent: c28cfd821b
Commit: 4c4159c6fa

5 changed files with 58 additions and 0 deletions
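For context on the ExpOp part of the title: NVIDIA GPUs expose a fast approximate base-2 exponential (the ex2.approx.f32 PTX instruction) rather than a base-e one, so exp(x) is computed as 2^(x * log2 e). The CUDA sketch below only illustrates that identity; the kernel name and host driver are hypothetical examples for this page, not the Triton lowering added by this commit.

#include <cstdio>
#include <cuda_runtime.h>

// Illustration of the identity exp(x) = 2^(x * log2(e)).
// On NVIDIA GPUs, exp2f maps closely to the ex2.approx.f32 PTX
// instruction (exact codegen depends on compiler flags).
__global__ void exp_via_ex2(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    const float kLog2e = 1.4426950408889634f;  // log2(e)
    out[i] = exp2f(in[i] * kLog2e);
  }
}

int main() {
  const int n = 4;
  float h_in[n] = {0.0f, 1.0f, 2.0f, -1.0f};
  float h_out[n];
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  exp_via_ex2<<<1, 32>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("exp(%f) ~= %f\n", h_in[i], h_out[i]);
  }
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}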


@@ -9,6 +9,8 @@
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -313,3 +315,5 @@ func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
return
// CHECK-NEXT: size = 40960
}
}


@@ -9,6 +9,8 @@
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
// There shouldn't be any membar with the dot op encoding.
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
@@ -250,3 +252,5 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
}