From 7027af96668cc319aae37eb8932892c4aa19f35a Mon Sep 17 00:00:00 2001
From: Yan Da
Date: Sun, 15 May 2022 22:29:27 +0800
Subject: [PATCH] The pipeline pass is now functional

---
 .../Dialect/TritonGPU/Transforms/Passes.td    |   4 +-
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp |  67 +++++---
 python/triton/code_gen.py                     |   1 +
 rewrite-test/jit/matmul/matmul.mlir           | 145 +++++++++---------
 4 files changed, 119 insertions(+), 98 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 4e50bcbd4..827e74485 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -21,7 +21,9 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
 
   let constructor = "mlir::triton::gpu::createPipelinePass";
 
-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::scf::SCFDialect",
+                           "mlir::arith::ArithmeticDialect"];
 }
 
 #endif
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index ec621d99f..9b6cb0073 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -2,6 +2,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include
 
 //===----------------------------------------------------------------------===//
 //
@@ -35,8 +36,6 @@ class LoopPipeliner {
   PipelineInfo info;
   /// value (in loop) => value at stage N
   DenseMap<Value, SmallVector<Value>> valueMapping;
-  /// stage => loop condition
-  DenseMap<int, Value> loopConds;
 
   DenseSet<BlockArgument> depArgs;
   DenseSet<Operation *> depOps;
@@ -142,7 +141,7 @@ void LoopPipeliner::emitPrologue() {
   // prologue from [0, numStage-1)
   auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-  Value iv = forOp.getInductionVar();
+  Value iv = forOp.getLowerBound();
   for (int stage = 0; stage < numStages - 1; ++stage) {
     // special handling for induction variable as the increment is implicit
     if (stage != 0)
@@ -152,7 +151,6 @@ void LoopPipeliner::emitPrologue() {
     // special handling for loop condition as there is no condition in ForOp
     Value loopCond = builder.create<arith::CmpIOp>(
       iv.getLoc(), arith::CmpIPredicate::slt, iv, forOp.getUpperBound());
-    loopConds[stage] = loopCond;
 
     // rematerialize peeled values
     SmallVector<Operation *> orderedDeps;
@@ -192,8 +190,11 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   //   (original args),
   //   (a at stage[0, numStages-1)), (b at stage[0, numStages-1))
   //   (depArgs at stage numStages-1)
-  //   (iv at stage numStages-1), (loopCond at stage numStages-1)
+  //   (iv at stage numStages-1)
   SmallVector<Value> newLoopArgs;
+  // We need this to update operands for yield
+  // original block arg => new arg's idx
+  DenseMap<BlockArgument, size_t> depArgsIdx;
   for (auto v : forOp.getIterOperands())
     newLoopArgs.push_back(v);
   size_t aArgIdx = newLoopArgs.size();
   for (int i = 0; i < numStages - 1; ++i)
     newLoopArgs.push_back(valueMapping[info.dotOp.a()][i]);
   size_t bArgIdx = newLoopArgs.size();
   for (int i = 0; i < numStages - 1; ++i)
     newLoopArgs.push_back(valueMapping[info.dotOp.b()][i]);
   size_t depArgsBeginIdx = newLoopArgs.size();
-  for (BlockArgument depArg : depArgs)
+  for (BlockArgument depArg : depArgs) {
+    depArgsIdx[depArg] = newLoopArgs.size();
     newLoopArgs.push_back(valueMapping[depArg][numStages-1]);
+  }
   size_t nextIVIdx = newLoopArgs.size();
-  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-1]);
-  newLoopArgs.push_back(loopConds[numStages-1]);
+  newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages-2]);
+
+  for (size_t i = 0; i < newLoopArgs.size(); ++i)
+    assert(newLoopArgs[i]);
 
   // signature of the new ForOp
   auto newForOp = builder.create<scf::ForOp>(forOp.getLoc(),
@@ -221,13 +226,18 @@ scf::ForOp LoopPipeliner::createNewForOp() {
   BlockAndValueMapping mapping;
   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
-  mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
-  mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
+  // mapping.map(info.dotOp.a(), newForOp.getRegionIterArgs()[aArgIdx]);
+  // mapping.map(info.dotOp.b(), newForOp.getRegionIterArgs()[bArgIdx]);
   for (Operation &op : forOp.getBody()->without_terminator()) {
     Operation *newOp = builder.clone(op, mapping);
     // update mapping of results
     for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
       mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
+    // TODO: why doesn't mapping work?
+    if (&op == info.dotOp.getOperation()) {
+      newOp->setOperand(0, newForOp.getRegionIterArgs()[aArgIdx]);
+      newOp->setOperand(1, newForOp.getRegionIterArgs()[bArgIdx]);
+    }
   }
   // prefetch next iteration
   SmallVector<Operation *> orderedDeps;
@@ -236,7 +246,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
       orderedDeps.push_back(&op);
   assert(depOps.size() == orderedDeps.size() && "depOps contains invalid values");
   BlockAndValueMapping nextMapping;
-  BlockAndValueMapping depArgsMapping;
+  DenseMap<BlockArgument, Value> depArgsMapping;
   size_t argIdx = 0;
   for (BlockArgument arg : depArgs) {
     nextMapping.map(arg, newForOp.getRegionIterArgs()[argIdx + depArgsBeginIdx]);
@@ -264,8 +274,19 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     }
     Operation *nextOp = builder.clone(*op, nextMapping);
     // update mapping of results
-    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
+    for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
       nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
+      // if this is a loop-carried value, update the mapping for yield
+      auto originYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      for (OpOperand &operand : originYield->getOpOperands()) {
+        if (operand.get() == op->getResult(dstIdx)) {
+          size_t originIdx = operand.getOperandNumber();
+          size_t newArgIdx = depArgsIdx[forOp.getRegionIterArgs()[originIdx]];
+          BlockArgument newArg = newForOp.getRegionIterArgs()[newArgIdx];
+          depArgsMapping[newArg] = nextOp->getResult(dstIdx);
+        }
+      }
+    }
   }
 
   // Finally, the YieldOp, need to sync with the order of newLoopArgs
@@ -274,14 +295,16 @@ scf::ForOp LoopPipeliner::createNewForOp() {
     yieldValues.push_back(mapping.lookup(v));
   for (int i = 1; i < numStages - 1; ++i)
     yieldValues.push_back(newForOp.getRegionIterArgs()[aArgIdx + i]);
-  yieldValues.push_back(nextMapping.lookup(info.aLoadOp.getResult()));
+  yieldValues.push_back(nextMapping.lookup(info.dotOp.a()));
   for (int i = 1; i < numStages - 1; ++i)
     yieldValues.push_back(newForOp.getRegionIterArgs()[bArgIdx + i]);
-  yieldValues.push_back(nextMapping.lookup(info.bLoadOp.getResult()));
-  // TODO: deps
-  //
+  yieldValues.push_back(nextMapping.lookup(info.dotOp.b()));
+  for (size_t i = depArgsBeginIdx; i < nextIVIdx; ++i)
+    yieldValues.push_back(depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
   yieldValues.push_back(nextIV);
-  yieldValues.push_back(nextLoopCond);
+  builder.setInsertionPointToEnd(newForOp.getBody());
+  builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
+                               yieldValues);
 
   return newForOp;
 }
@@ -300,16 +323,14 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
       if (pipeliner.initialize().failed())
         return;
 
-      llvm::errs() << "candidate for pipelining: " << pipeliner.info.dotOp
-                   << "\n";
-
       pipeliner.emitPrologue();
 
       scf::ForOp newForOp = pipeliner.createNewForOp();
 
-      // // replace the original loop
-      // forOp->replaceAllUsesWith(newForOp->getResults());
-      // forOp->erase();
+      // replace the original loop
+      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
+        forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
+      forOp->erase();
     });
   }
 };
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index b5999648b..2ed06528a 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1315,6 +1315,7 @@ class JITFunction:
         pm.add_canonicalizer_pass()
         pm.add_convert_triton_to_tritongpu_pass()
         pm.add_tritongpu_pipeline_pass()
+        pm.add_canonicalizer_pass()
         pm.run(mod)
         return mod
diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir
index eb240cd28..a93754229 100644
--- a/rewrite-test/jit/matmul/matmul.mlir
+++ b/rewrite-test/jit/matmul/matmul.mlir
@@ -128,41 +128,17 @@ module {
     return %1 : i32
   }
 }
-is yield legal?
-scf.yield %80, %82, %85 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr>
-is legal: 0
-converting for...:
-%78:3 = scf.for %arg9 = %c0 to %77 step %c128 iter_args(%arg10 = %75, %arg11 = %51, %arg12 = %73) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>) {
-  %109 = tt.load <>, %cst_2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
-  %110 = tt.load <>, %cst_2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16>
-  %111 = tt.dot %109, %110, <> {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32>
-  %112 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32>
-  %113 = tt.getelementptr <>, %112, : tensor<128x128x!tt.ptr>
-  %114 = arith.muli %arg7, %c128_i32 : i32
-  %115 = tt.broadcast %114 : (i32) -> tensor<128x128xi32>
-  %116 = tt.getelementptr <>, %115, : tensor<128x128x!tt.ptr>
-  scf.yield %111, %113, %116 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr>
-}
-converting dot...
-%113 = tt.dot %109, %111, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> * tensor<128x128xf16, #triton_gpu<"coalesced encoding">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">>
-is yield legal?
-scf.yield %114, %118, %123 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr>
-is legal: 0
-converting yield....
-is yield legal?
-scf.yield %113, %117, %122 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-is legal: 1
 module {
   func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
-    %c1_i32 = arith.constant 1 : i32
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
-    %cst_0 = arith.constant dense : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-    %c128 = arith.constant 128 : index
-    %c0 = arith.constant 0 : index
-    %cst_1 = arith.constant 0.000000e+00 : f32
-    %c128_i32 = arith.constant 128 : i32
-    %c127_i32 = arith.constant 127 : i32
     %c8_i32 = arith.constant 8 : i32
+    %c127_i32 = arith.constant 127 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %cst_0 = arith.constant dense : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+    %c1_i32 = arith.constant 1 : i32
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.addi %arg3, %c127_i32 : i32
     %2 = arith.divsi %1, %c128_i32 : i32
@@ -209,54 +185,75 @@ module {
     %43 = arith.addi %41, %42 : tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
     %44 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
     %45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-    %46 = tt.broadcast %cst_1 : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">>
+    %46 = tt.broadcast %cst : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">>
     %47 = arith.index_cast %arg5 : i32 to index
-    %48:3 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>) {
-      %78 = tt.load %arg11, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
-      %79 = tt.load %arg12, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
-      %80 = tt.dot %78, %79, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> * tensor<128x128xf16, #triton_gpu<"coalesced encoding">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">>
-      %81 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
-      %82 = tt.getelementptr %arg11, %81, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-      %83 = arith.muli %arg7, %c128_i32 : i32
-      %84 = tt.broadcast %83 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
-      %85 = tt.getelementptr %arg12, %84, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-      scf.yield %80, %82, %85 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %48 = tt.load %34, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+    %49 = tt.load %45, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+    %50 = "triton_gpu.convert_layout"(%48) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+    %51 = "triton_gpu.convert_layout"(%49) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+    %52 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+    %53 = tt.getelementptr %34, %52, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %54 = arith.muli %arg7, %c128_i32 : i32
+    %55 = tt.broadcast %54 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+    %56 = tt.getelementptr %45, %55, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %57:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %50, %arg14 = %51, %arg15 = %56, %arg16 = %53, %arg17 = %c0) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) {
+      %87 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">>
+      %88 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %89 = tt.getelementptr %arg11, %88, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      %90 = arith.muli %arg7, %c128_i32 : i32
+      %91 = tt.broadcast %90 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %92 = tt.getelementptr %arg12, %91, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      %93 = arith.addi %arg17, %c128 : index
+      %94 = arith.cmpi slt, %93, %47 : index
+      %95 = tt.broadcast %94 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+      %96 = tt.load %arg16, %95, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+      %97 = tt.broadcast %94 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+      %98 = arith.andi %97, %95 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+      %99 = tt.load %arg15, %98, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+      %100 = "triton_gpu.convert_layout"(%96) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+      %101 = "triton_gpu.convert_layout"(%99) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>
+      %102 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %103 = tt.getelementptr %arg16, %102, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      %104 = arith.muli %arg7, %c128_i32 : i32
+      %105 = tt.broadcast %104 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+      %106 = tt.getelementptr %arg15, %105, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+      scf.yield %87, %89, %92, %100, %101, %106, %103, %93 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index
     }
-    %49 = arith.truncf %48#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
-    %50 = arith.muli %12, %c128_i32 : i32
-    %51 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %52 = tt.broadcast %50 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %53 = arith.addi %52, %51 : tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %54 = arith.muli %14, %c128_i32 : i32
-    %55 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %56 = tt.broadcast %54 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %57 = arith.addi %56, %55 : tensor<128xi32, #triton_gpu<"coalesced encoding">>
-    %58 = tt.reshape %53 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
-    %59 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
-    %60 = arith.muli %59, %58 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
-    %61 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>
-    %62 = tt.getelementptr %61, %60, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>
-    %63 = tt.reshape %57 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
-    %64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
-    %65 = arith.muli %63, %64 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
-    %66 = tt.broadcast %62 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-    %67 = tt.broadcast %65 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
-    %68 = tt.getelementptr %66, %67, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
-    %69 = tt.reshape %53 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
-    %70 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
-    %71 = "triton_gpu.cmpi"(%69, %70) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding">>
-    %72 = tt.reshape %57 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
-    %73 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
-    %74 = "triton_gpu.cmpi"(%72, %73) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">>
-    %75 = tt.broadcast %71 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-    %76 = tt.broadcast %74 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-    %77 = arith.andi %75, %76 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
-    tt.store %68, %49, %77, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+    %58 = arith.truncf %57#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
+    %59 = arith.muli %12, %c128_i32 : i32
+    %60 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %61 = tt.broadcast %59 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %62 = arith.addi %61, %60 : tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %63 = arith.muli %14, %c128_i32 : i32
+    %64 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %65 = tt.broadcast %63 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %66 = arith.addi %65, %64 : tensor<128xi32, #triton_gpu<"coalesced encoding">>
+    %67 = tt.reshape %62 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
+    %68 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
+    %69 = arith.muli %68, %67 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
+    %70 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %71 = tt.getelementptr %70, %69, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %72 = tt.reshape %66 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
+    %73 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
+    %74 = arith.muli %72, %73 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
+    %75 = tt.broadcast %71 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %76 = tt.broadcast %74 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">>
+    %77 = tt.getelementptr %75, %76, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>
+    %78 = tt.reshape %62 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
+    %79 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">>
+    %80 = "triton_gpu.cmpi"(%78, %79) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding">>
+    %81 = tt.reshape %66 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
+    %82 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">>
+    %83 = "triton_gpu.cmpi"(%81, %82) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">>
+    %84 = tt.broadcast %80 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+    %85 = tt.broadcast %83 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+    %86 = arith.andi %84, %85 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">>
+    tt.store %77, %58, %86, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">>
     return
   }
   func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {
-    %c127_i32 = arith.constant 127 : i32
     %c128_i32 = arith.constant 128 : i32
+    %c127_i32 = arith.constant 127 : i32
     %0 = arith.addi %arg0, %c127_i32 : i32
     %1 = arith.divsi %0, %c128_i32 : i32
     return %1 : i32
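
For readers unfamiliar with the transformation shown in the rewritten matmul.mlir: the pass peels the tt.load operations for the first iteration into a prologue and then, inside the new scf.for, issues the masked loads for iteration i+1 while tt.dot consumes the tiles loaded for iteration i, carrying the prefetched tiles and the next induction-variable value as extra iter_args. Below is a minimal Python sketch of that schedule; it is illustrative only, not part of the patch, and names such as load_a, load_b and compute are hypothetical.

# Illustrative 2-stage software-pipelining schedule, mirroring the rewritten
# scf.for above: each iteration computes on tiles loaded by the previous
# iteration and prefetches (with a bounds guard, like the masked tt.load)
# the tiles for the next one. All names are hypothetical.
def pipelined_loop(load_a, load_b, compute, lb, ub, step, acc):
    # Prologue: peel the loads for the first iteration (numStages - 1 = 1 stage).
    a_next, b_next = load_a(lb), load_b(lb)
    iv = lb
    while iv < ub:
        # Consume the already-loaded tiles (the loop-carried iter_args).
        acc = compute(a_next, b_next, acc)
        next_iv = iv + step
        if next_iv < ub:  # guarded prefetch for the next iteration
            a_next, b_next = load_a(next_iv), load_b(next_iv)
        iv = next_iv
    return acc

if __name__ == "__main__":
    # Toy usage: a dot product over a K dimension of 8 elements.
    A = list(range(8))
    B = list(range(8))
    out = pipelined_loop(lambda k: A[k], lambda k: B[k],
                         lambda a, b, acc: acc + a * b,
                         lb=0, ub=8, step=1, acc=0)
    assert out == sum(a * b for a, b in zip(A, B))
    print(out)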