reorder conversions to dot operand

Author: Phil Tillet
Date:   2023-01-09 20:11:22 -08:00
Parent: c98c889d7f
Commit: bae4c40379
2 changed files with 26 additions and 14 deletions

File 1 of 2 (C++):

@@ -64,7 +64,6 @@ public:
       Operation* argOp = op.getOperand().getDefiningOp();
       if(!argOp)
         return;
-      llvm::outs() << "moving " << *op << "\n";
       op->moveAfter(argOp);
     });
     // Move transpositions just after their definition
@@ -75,7 +74,20 @@ public:
         return;
       op->moveAfter(argOp);
     });
+    // Move `dot` operand so that conversions to opIdx=0 happens before conversions to opIdx=1
+    m.walk([&](triton::gpu::ConvertLayoutOp op){
+      auto dstType = op.getResult().getType().cast<RankedTensorType>();
+      auto dstEncoding = dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if(!dstEncoding)
+        return;
+      int opIdx = dstEncoding.getOpIdx();
+      if(opIdx != 1)
+        return;
+      if(op->getUsers().empty())
+        return;
+      auto user_begin = op->user_begin();
+      op->moveBefore(*user_begin);
+    });
     return;
   }
 };
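
For readers skimming the hunk above, here is a commented sketch of the walk this commit adds, with fully qualified names. It is illustrative only: the free-function wrapper sinkDotOperandBConversions and the include paths are assumptions, while the individual API calls (walk, cast, dyn_cast, getOpIdx, getUsers, user_begin, moveBefore) are the ones visible in the diff itself.

// Illustrative sketch only -- not part of this commit. Include paths and the
// wrapper function are assumptions; the body mirrors the walk added above.
#include "mlir/IR/BuiltinOps.h"                  // assumed include path
#include "triton/Dialect/TritonGPU/IR/Dialect.h" // assumed include path

static void sinkDotOperandBConversions(mlir::ModuleOp m) {
  m.walk([&](mlir::triton::gpu::ConvertLayoutOp op) {
    // Only consider conversions whose result carries a dot-operand encoding.
    auto dstType = op.getResult().getType().cast<mlir::RankedTensorType>();
    auto dstEncoding =
        dstType.getEncoding().dyn_cast<mlir::triton::gpu::DotOperandEncodingAttr>();
    if (!dstEncoding)
      return;
    // Conversions to opIdx=0 stay where they are; only opIdx=1 is moved.
    if (dstEncoding.getOpIdx() != 1)
      return;
    // A dead conversion has no user to sink to.
    if (op->getUsers().empty())
      return;
    // Sink the opIdx=1 conversion to just before its first user (typically the
    // tt.dot), so the opIdx=0 conversion ends up emitted first.
    op->moveBefore(*op->user_begin());
  });
}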

File 2 of 2 (MLIR):

@@ -101,9 +101,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
-%92 = triton_gpu.convert_layout %61 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-%93 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-%94 = tt.dot %93, %92, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+%93 = triton_gpu.convert_layout %61 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+%94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
@@ -131,22 +131,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
-%122 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
-%123 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
-%124 = tt.dot %123, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
+%122 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
+%123 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
+%124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
-%130 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-%131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-%132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+%130 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+%131 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+%132 = tt.dot %130, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%134 = triton_gpu.convert_layout %133 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
-%135 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
-%136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
-%137 = tt.dot %136, %135, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
+%135 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
+%136 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
+%137 = tt.dot %135, %136, %134 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%138 = triton_gpu.convert_layout %137 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %138 : tensor<128x64xf32, #blocked2>
%139 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
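
In the updated IR above, each tt.dot now receives its opIdx=0 operand from a convert_layout that appears before the corresponding opIdx=1 conversion. Below is a minimal sketch, not part of the commit, of how that ordering could be asserted over a module; the function name and include paths are assumptions, and the mlir::Operation queries used (getDefiningOp, getBlock, isBeforeInBlock) are standard MLIR APIs.

// Illustrative sketch only -- not part of this commit. Checks, per tt.dot,
// that the defining op of operand 0 precedes the defining op of operand 1
// whenever both live in the same block.
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/Triton/IR/Dialect.h" // assumed include path

static bool dotOperandConversionsOrdered(mlir::ModuleOp m) {
  bool ordered = true;
  m.walk([&](mlir::triton::DotOp dot) {
    mlir::Operation *aDef = dot->getOperand(0).getDefiningOp();
    mlir::Operation *bDef = dot->getOperand(1).getDefiningOp();
    // Only comparable when both operands are defined in the same block.
    if (!aDef || !bDef || aDef->getBlock() != bDef->getBlock())
      return;
    if (!aDef->isBeforeInBlock(bDef))
      ordered = false;
  });
  return ordered;
}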