[FRONTEND] Add scalar type support for some ops (#661)

This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetic) for Triton-MLIR. It also renames getelementptr to addptr.
This commit is contained in:
Shintaro Iwasaki
2022-09-15 16:12:52 -07:00
committed by GitHub
parent 2e08450c80
commit 43be75ad42
27 changed files with 203 additions and 129 deletions

View File

@@ -28,20 +28,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
return