Added verifier for trans

This commit is contained in:
Phil Tillet
2023-01-08 14:29:17 -08:00
parent 42421fabc5
commit 6c750b6856
9 changed files with 243 additions and 200 deletions

View File

@@ -25,6 +25,10 @@ class DialectInferLayoutInterface
 public:
   DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
 
+  virtual LogicalResult
+  inferTransOpEncoding(Attribute operandEncoding,
+                       Attribute &resultEncoding) const = 0;
+
   virtual LogicalResult
   inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                         Attribute &resultEncoding) const = 0;

View File

@@ -289,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
 }
 
 def TT_TransOp : TT_Op<"trans", [NoSideEffect,
-                                 SameOperandsAndResultElementType]> {
+                                 DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
 
   let summary = "transpose a tensor";
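With DeclareOpInterfaceMethods<InferTypeOpInterface> on tt.trans, ODS also generates a builder that takes no explicit result type and computes it through inferReturnTypes. A minimal sketch of what that enables is below; it is illustrative only (the helper name and the rewriter/loc/src values are assumed to come from a surrounding rewrite pattern, and the required Triton/MLIR headers are omitted):

// Sketch, not part of this commit: relies on the ODS-generated
// inferred-return-type builder that ops declaring InferTypeOpInterface receive.
static mlir::Value buildTranspose(mlir::PatternRewriter &rewriter,
                                  mlir::Location loc, mlir::Value src) {
  // No explicit result type: TransOp::inferReturnTypes (added later in this
  // commit) reverses the shape and asks the dialect interface for the layout.
  return rewriter.create<mlir::triton::TransOp>(loc, src);
}

This is what lets the simplified TritonTransPattern below call replaceOpWithNewOp<triton::TransOp>(op, src) without building the transposed tensor type by hand.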

View File

@@ -319,20 +319,7 @@ struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
       src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
                                                           src);
     }
-    auto srcSharedEncoding =
-        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
-    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
-                                   srcSharedEncoding.getOrder().end());
-    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
-                                   srcType.getShape().end());
-    std::reverse(retOrder.begin(), retOrder.end());
-    std::reverse(retShapes.begin(), retShapes.end());
-    auto retEncoding =
-        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
-    auto retType =
-        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);
-    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
+    rewriter.replaceOpWithNewOp<triton::TransOp>(op, src);
     return success();
   }
 };

View File

@@ -206,6 +206,36 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
   state.addTypes({resultType});
 }
 
+//-- TransOp --
+mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // The result keeps the operand's element type; its shape (and, if present,
+  // its layout) is the operand's, transposed.
+  auto argTy = operands[0].getType().cast<RankedTensorType>();
+  SmallVector<int64_t> retShape(argTy.getShape().begin(),
+                                argTy.getShape().end());
+  std::reverse(retShape.begin(), retShape.end());
+  auto retEltTy = argTy.getElementType();
+  Attribute argEncoding = argTy.getEncoding();
+  Attribute retEncoding;
+  if (argEncoding) {
+    Dialect &dialect = argEncoding.getDialect();
+    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
+    if (inferLayoutInterface
+            ->inferTransOpEncoding(argEncoding, retEncoding)
+            .failed()) {
+      llvm::report_fatal_error("failed to infer layout for TransOp");
+      return mlir::failure();
+    }
+  }
+  inferredReturnTypes.push_back(
+      RankedTensorType::get(retShape, retEltTy, retEncoding));
+  return mlir::success();
+}
+
 //-- DotOp --
 mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
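To make the shape side of the rule above concrete, here is a small self-contained sketch in plain C++ (no MLIR dependencies; the tensor<...> strings are only for printing): the element type is kept and the dimensions are reversed, so a 128x64 f16 tensor transposes to a 64x128 f16 tensor. Layout inference is delegated to the dialect interface and is shown in the next file.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the shape handling of TransOp::inferReturnTypes: reverse the
// dimensions, keep the element type.
static std::string inferTransType(std::vector<int64_t> shape,
                                  const std::string &eltTy) {
  std::reverse(shape.begin(), shape.end());
  std::string ty = "tensor<";
  for (int64_t d : shape)
    ty += std::to_string(d) + "x";
  return ty + eltTy + ">";
}

int main() {
  // Prints tensor<64x128xf16> for a tensor<128x64xf16> operand.
  std::cout << inferTransType({128, 64}, "f16") << "\n";
  return 0;
}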

View File

@@ -737,6 +737,22 @@ struct TritonGPUInferLayoutInterface
     return success();
   }
 
+  LogicalResult
+  inferTransOpEncoding(Attribute operandEncoding, Attribute &resultEncoding) const {
+    SharedEncodingAttr sharedEncoding = operandEncoding.dyn_cast<SharedEncodingAttr>();
+    if (!sharedEncoding)
+      return failure();
+    SmallVector<unsigned> retOrder(sharedEncoding.getOrder().begin(),
+                                   sharedEncoding.getOrder().end());
+    std::reverse(retOrder.begin(), retOrder.end());
+    resultEncoding = SharedEncodingAttr::get(getDialect()->getContext(),
+                                             sharedEncoding.getVec(),
+                                             sharedEncoding.getPerPhase(),
+                                             sharedEncoding.getMaxPhase(),
+                                             retOrder);
+    return mlir::success();
+  }
+
   LogicalResult
   inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
                             Attribute &resultEncoding,

View File

@@ -3,15 +3,14 @@
 // TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
 // don't convert loaded value to mma for accumulation
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
 #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
-#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
-#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
@@ -21,7 +20,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
     %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
-    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.divsi %0, %arg22 : i32
     %2 = arith.remsi %0, %arg22 : i32
@@ -84,94 +82,94 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> %61 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> %62 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1> %63 = tt.broadcast %62 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %64 = arith.addi %63, %24 : tensor<128x64xi32, #blocked1>
%65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %65 = tt.addptr %30, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %66 = tt.load %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%67 = arith.index_cast %47 : i32 to index %67 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1>
%68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %68 = arith.index_cast %47 : i32 to index
%69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> %69 = tt.trans %61 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2>
%70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> %70 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> %71 = tt.expand_dims %70 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %72 = tt.broadcast %71 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> %73 = tt.trans %67 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> %74 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> %75 = tt.broadcast %74 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %76 = arith.addi %75, %26 : tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %77 = tt.addptr %32, %76 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %78 = tt.addptr %27, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) { %79 = tt.addptr %31, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%86 = arith.index_cast %arg26 : index to i32 %80:5 = scf.for %arg26 = %68 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %77, %arg30 = %78, %arg31 = %79) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0> %87 = arith.index_cast %arg26 : index to i32
%88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0>
%89 = arith.addi %87, %14 : tensor<128xi32, #blocked0> %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0>
%91 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%93 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %93 = triton_gpu.convert_layout %69 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %94 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %95 = tt.dot %94, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> %96 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> %97 = tt.expand_dims %96 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> %98 = tt.broadcast %97 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> %99 = "triton_gpu.cmpi"(%98, %72) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %100 = "triton_gpu.select"(%99, %95, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %101 = tt.addptr %38, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %102 = tt.load %101 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> %103 = triton_gpu.convert_layout %102 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> %104 = arith.mulf %100, %39 : tensor<128x128xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> %105 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0> %106 = tt.broadcast %105 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0> %107 = arith.subf %104, %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %108 = math.exp %107 : tensor<128x128xf32, #mma0>
%109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> %109 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> %110 = triton_gpu.convert_layout %109 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> %111 = arith.truncf %108 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%113 = triton_gpu.convert_layout %112 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %113 = tt.trans %112 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2>
%114 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %114 = triton_gpu.convert_layout %113 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%115 = tt.dot %114, %113, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %115 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %116 = tt.dot %114, %115, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %117 = tt.addptr %40, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %118 = tt.load %117 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> %119 = triton_gpu.convert_layout %118 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> %120 = tt.expand_dims %119 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0> %121 = tt.broadcast %120 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %112 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %122 = arith.subf %cst_1, %121 : tensor<128x128xf32, #mma0>
%123 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %123 = triton_gpu.convert_layout %110 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%124 = tt.dot %122, %123, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %124 = triton_gpu.convert_layout %73 : (tensor<64x128xf16, #shared2>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%125 = arith.mulf %107, %124 : tensor<128x128xf32, #mma0> %125 = tt.dot %123, %124, %122 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%126 = arith.mulf %125, %39 : tensor<128x128xf32, #mma0> %126 = arith.mulf %108, %125 : tensor<128x128xf32, #mma0>
%127 = arith.truncf %126 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0>
%128 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%129 = tt.trans %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared1>
%130 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %130 = tt.trans %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #shared2>
%131 = triton_gpu.convert_layout %129 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %131 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%132 = tt.dot %131, %130, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %132 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared2>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%133 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> %133 = tt.dot %132, %131, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%135 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2> %134 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %135 = triton_gpu.convert_layout %134 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
%137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %136 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%138 = tt.dot %137, %136, %cst_2 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %137 = triton_gpu.convert_layout %60 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%138 = tt.dot %136, %137, %135 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%139 = triton_gpu.convert_layout %138 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
%1000 = arith.addf %133, %139: tensor<128x64xf32, #blocked2> tt.store %arg29, %139 : tensor<128x64xf32, #blocked2>
tt.store %arg29, %133 : tensor<128x64xf32, #blocked2>
%140 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %140 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%141 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %141 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%142 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %142 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %132, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1> scf.yield %116, %133, %140, %141, %142 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
} }
%80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %82 = tt.addptr %44, %64 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %81, %82 : tensor<128x64xf16, #blocked1> tt.store %82, %83 : tensor<128x64xf16, #blocked1>
%83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %84, %85 : tensor<128x64xf16, #blocked1> tt.store %85, %86 : tensor<128x64xf16, #blocked1>
} }
return return
} }
} }

View File

@@ -908,6 +908,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
     # pm.add_tritongpu_optimize_load_convert_pass()
     pm.add_tritongpu_sink_conversions_from_shared_pass()
     pm.add_tritongpu_decompose_conversions_to_dot_operand_pass()
+    pm.add_cse_pass()
     pm.run(mod)
     return mod

View File

@@ -191,7 +191,8 @@ def _bwd_kernel(
     tl.store(dv_ptrs, dv)
     tl.store(dk_ptrs, dk)
 
-_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
+# _bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8)
+# _bwd_kernel = triton.compile("./unoptimized.ttgir", num_warps=8)
 # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8)
 # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432)
@@ -259,36 +260,36 @@ class _attention(torch.autograd.Function):
             BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
         )
-        _bwd_kernel[(ctx.grid[1],1,1)](
-            q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
-            o.data_ptr(), do_scaled.data_ptr(),
-            dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
-            l.data_ptr(), m.data_ptr(),
-            delta.data_ptr(),
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            q.shape[0], q.shape[1], q.shape[2],
-            ctx.grid[0]
-        )
-        # pgm = _bwd_kernel[(ctx.grid[1],)](
-        #     q, k, v, ctx.sm_scale,
-        #     o, do_scaled,
-        #     dq, dk, dv,
-        #     l, m,
-        #     delta,
-        #     q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        #     k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        #     v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        #     q.shape[0], q.shape[1], q.shape[2],
-        #     ctx.grid[0],
-        #     BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-        #     BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
-        #     num_stages=1,
-        # )
-        # print(pgm.asm["ttgir"])
-        # exit()
+        # _bwd_kernel[(ctx.grid[1],1,1)](
+        #     q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale,
+        #     o.data_ptr(), do_scaled.data_ptr(),
+        #     dq.data_ptr(), dk.data_ptr(), dv.data_ptr(),
+        #     l.data_ptr(), m.data_ptr(),
+        #     delta.data_ptr(),
+        #     q.stride(0), q.stride(1), q.stride(2),
+        #     k.stride(0), k.stride(1), k.stride(2),
+        #     v.stride(0), v.stride(1), v.stride(2),
+        #     q.shape[0], q.shape[1], q.shape[2],
+        #     ctx.grid[0]
+        # )
+        pgm = _bwd_kernel[(ctx.grid[1],)](
+            q, k, v, ctx.sm_scale,
+            o, do_scaled,
+            dq, dk, dv,
+            l, m,
+            delta,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0],
+            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            num_stages=1,
+        )
+        print(pgm.asm["ttgir"])
+        exit()
         return dq, dk, dv, None

View File

@@ -11,6 +11,7 @@
 #mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
@@ -81,91 +82,96 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
%57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1> %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
%58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%60 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
%62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1> %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
%64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%65 = arith.index_cast %47 : i32 to index %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
%67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %67 = arith.index_cast %47 : i32 to index
%68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> %68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> %69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
%70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> %70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
%71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> %71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> %72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
%73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
%74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
%75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
%76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%79 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
%80:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) { %86 = arith.index_cast %arg26 : index to i32
%87 = arith.index_cast %arg26 : index to i32 %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
%88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0> %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
%90 = arith.addi %88, %14 : tensor<128xi32, #blocked0> %90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %91 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %93 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %94 = tt.dot %93, %91, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
%97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
%98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> %98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
%99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> %99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%100 = tt.addptr %38, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> %102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
%104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> %104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> %105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0> %106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0>
%107 = math.exp %106 : tensor<128x128xf32, #mma0> %107 = math.exp %106 : tensor<128x128xf32, #mma0>
%108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
%109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> %109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %114 = triton_gpu.convert_layout %113 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%115 = tt.addptr %40, %90 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0> %115 = tt.dot %112, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> %116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
%117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
%118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> %118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
%119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> %119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
%120 = arith.subf %cst_1, %119 : tensor<128x128xf32, #mma0> %120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
%121 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> %121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
%122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
%123 = tt.dot %121, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%124 = arith.mulf %107, %123 : tensor<128x128xf32, #mma0> %124 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
%125 = arith.mulf %124, %39 : tensor<128x128xf32, #mma0> %125 = tt.dot %124, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
%126 = arith.truncf %125 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> %126 = arith.mulf %107, %125 : tensor<128x128xf32, #mma0>
%127 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0>
%128 = tt.trans %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
%129 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
%130 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %130 = tt.trans %129 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
%131 = tt.dot %130, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %131 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%133 = triton_gpu.convert_layout %132 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> %133 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%134 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> %134 = tt.dot %133, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
%135 = tt.dot %134, %79, %133 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %135 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
%136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
tt.store %arg29, %136 : tensor<128x64xf32, #blocked2> %137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
%137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2> %138 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
%138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
%139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %140 = tt.dot %137, %139, %136 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
scf.yield %114, %131, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1> %141 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
tt.store %arg29, %141 : tensor<128x64xf32, #blocked2>
%142 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
%143 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%144 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
scf.yield %115, %134, %142, %143, %144 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
} }
%81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> %80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> %82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %82, %83 : tensor<128x64xf16, #blocked1> tt.store %81, %82 : tensor<128x64xf16, #blocked1>
%84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> %83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
%85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1> %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> %85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
tt.store %85, %86 : tensor<128x64xf16, #blocked1> tt.store %84, %85 : tensor<128x64xf16, #blocked1>
} }
return return
} }