[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
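The hunks below update lit tests that exercise tt.dot, the IR-level matrix-multiply op. For context, here is a minimal Triton-Python sketch (illustrative only, not code from this commit; kernel and argument names are hypothetical) of the fused-attention pattern the title refers to: two chained dots per key/value block, which lower to back-to-back tt.dot ops.

    import triton
    import triton.language as tl


    # Illustrative sketch, not part of this PR. Assumes contiguous row-major tensors
    # whose head dimension equals BLOCK_D; softmax normalization and the online
    # rescaling of a real flash-attention kernel are omitted for brevity.
    @triton.jit
    def _attn_fwd_sketch(Q, K, V, Out,
                         N_CTX: tl.constexpr,
                         BLOCK_M: tl.constexpr,
                         BLOCK_N: tl.constexpr,
                         BLOCK_D: tl.constexpr):
        offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_d = tl.arange(0, BLOCK_D)
        # One block of queries stays resident while we sweep over key/value blocks.
        q = tl.load(Q + offs_m[:, None] * BLOCK_D + offs_d[None, :])
        acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
        for start_n in range(0, N_CTX, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Load K already transposed (BLOCK_D x BLOCK_N) so the first dot is q @ k^T.
            k = tl.load(K + offs_d[:, None] + offs_n[None, :] * BLOCK_D)
            v = tl.load(V + offs_n[:, None] * BLOCK_D + offs_d[None, :])
            # First dot: attention scores (lowers to a tt.dot like those in the tests below).
            scores = tl.dot(q, k)
            p = tl.exp(scores - tl.max(scores, axis=1)[:, None])
            # Second dot consumes the first dot's result: these chained dots are where
            # layout conversions between ops show up and need to be simplified cheaply.
            acc += tl.dot(p.to(v.dtype), v)
        tl.store(Out + offs_m[:, None] * BLOCK_D + offs_d[None, :], acc)

Because the second tl.dot feeds directly on the first one's output, any layout conversions inserted between them are presumably what the layout simplification pass has to fold away efficiently in these kernels.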
@@ -113,13 +113,13 @@ func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
 %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

 // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
-%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
+%r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
 // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
-%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
+%r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
 // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
-%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
+%r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
 // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
-%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
+%r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

 %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
 %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
@@ -5,7 +5,7 @@ func @ops() {
 %a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
 %b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
 %c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
-%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
+%0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
 return
 }

@@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
 // CHECK: llvm.inline_asm
 // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
-%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
+%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>

 return
 }