From 3b80801dff1ec9093d9b4f8c4254158146595b4a Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 27 Oct 2022 22:09:06 -0700 Subject: [PATCH] [Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809) 1. Rewrite code generation of insert_slice_async. 2. Correct the wrong index passed to extract_slice in pipeline. 3. Add a prologue in pipeline to wait for dangling cp.asyncs. 4. Move scf to cf conversion inside TritonGPUToLLVM because we need to perform membar before scf to cf. It shouldn't be a technical limitation and could be improved by a more general membar analysis. 5. Use an attribute to memoize the shared memory size and support dynamic shared memory. 6. Prevent the combine pass to reorder insert_slice and extract_slice across async_wait Co-authored-by: Superjomn --- .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 4 +- lib/Analysis/Membar.cpp | 6 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 86 ++++++++++++------- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 12 ++- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 22 +++++ lib/Target/LLVMIR/LLVMIRTranslation.cpp | 4 +- python/src/triton.cc | 4 +- python/triton/compiler.py | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 46 +++++----- test/TritonGPU/loop-pipeline.mlir | 6 +- 10 files changed, 122 insertions(+), 70 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index 82c20a639..3ad8316ba 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -292,13 +292,11 @@ struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase { struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase { explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, - triton::CacheModifier modifier, - triton::EvictionPolicy policy) + triton::CacheModifier modifier) : PTXCpAsyncInstrBase(builder) { o(triton::stringifyCacheModifier(modifier).str()); o("shared"); o("global"); - o("L2::" + triton::stringifyEvictionPolicy(policy).str()); } }; diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index e2d386fd4..6b03c9947 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -42,18 +42,14 @@ void MembarAnalysis::dfsOperation(Operation *operation, void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo, OpBuilder *builder) { - if (op->getNumResults() < 1) - return; - if (isa(op) || isa(op) || isa(op) || isa(op) || - isa(op) || isa(op)) { // Do not insert barriers before control flow operations and // alloc/extract/insert // alloc is an allocation op without memory write. // In contrast, arith.constant is an allocation op with memory write. 
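Note on the PtxAsmFormat.h hunk above: with the eviction-policy operand removed, the generated instruction is plain cp.async.cg.shared.global (or .ca for copies smaller than 16 bytes) rather than cp.async.cg.shared.global.L2::evict_normal. The standalone CUDA sketch below is illustrative only, not Triton's generated code (the function names are hypothetical); it shows the instruction shape the reworked insert_slice_async lowering emits, including the src-size trick used instead of predication: a masked-off element copies 0 source bytes and cp.async zero-fills the remaining cp-size bytes.

#include <cstdint>

// Issue one 16-byte async copy in the shape the lowered PTX takes after this
// patch (no L2 eviction-policy qualifier). Requires sm_80 or newer.
__device__ void cp_async_16B(void *smem_dst, const void *gmem_src, bool mask) {
  // cp.async expects a shared-state-space address for the destination operand.
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  // Masked elements copy 0 source bytes; the hardware zero-fills up to cp-size.
  uint32_t src_size = mask ? 16u : 0u;
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
               :
               : "r"(dst), "l"(gmem_src), "r"(src_size));
}

// One commit per insert_slice_async, matching the cp.async.commit_group the
// conversion appends after the per-word copies.
__device__ void cp_async_commit() {
  asm volatile("cp.async.commit_group;");
}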
- // FIXME(Keren): extract and insert are always alias for now + // FIXME(Keren): extract is always alias for now return; } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index fda2f4a61..50470ac88 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -6,6 +6,7 @@ #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" @@ -14,6 +15,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" @@ -108,7 +110,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define i32_ty rewriter.getIntegerType(32) #define f32_ty rewriter.getF32Type() #define vec_ty(type, num) VectorType::get(num, type) -#define void_ty LLVM::LLVMVoidType::get(ctx) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__) // Creator for constant @@ -360,7 +362,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, auto *ptrOpr = builder.newAddrOperand(ptr, "r"); auto *valOpr = builder.newOperand(val, c); st(ptrOpr, valOpr).predicate(pred, "b"); - return builder.launch(rewriter, loc, void_ty); + return builder.launch(rewriter, loc, void_ty(ctx)); } struct ConvertTritonGPUOpToLLVMPatternBase { @@ -1151,7 +1153,7 @@ struct StoreOpConversion llvm::SmallVector argTys({boolTy, ptr.getType()}); argTys.insert(argTys.end(), nWords, valArgTy); - auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx); + auto ASMReturnTy = void_ty(ctx); ptxBuilder.launch(rewriter, loc, ASMReturnTy); } @@ -3869,7 +3871,7 @@ struct AsyncWaitOpConversion auto ctx = op.getContext(); auto loc = op.getLoc(); - auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto voidTy = void_ty(ctx); auto ret = ptxBuilder.launch(rewriter, loc, voidTy); // Safe to remove the op since it doesn't have any return value. @@ -3907,7 +3909,7 @@ struct InsertSliceAsyncOpConversion auto srcTy = src.getType().cast(); auto resTy = dst.getType().cast(); - auto resElemTy = resTy.getElementType(); + auto resElemTy = getTypeConverter()->convertType(resTy.getElementType()); auto srcBlockedLayout = srcTy.getEncoding().cast(); auto resSharedLayout = resTy.getEncoding().cast(); auto srcShape = srcTy.getShape(); @@ -3928,7 +3930,7 @@ struct InsertSliceAsyncOpConversion assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now"); auto dstBase = createIndexAttrConstant(rewriter, loc, getTypeConverter()->getIndexType(), - product(resTy.getShape())); + product(srcTy.getShape())); Value offset = mul(llIndex, dstBase); auto dstPtrTy = LLVM::LLVMPointerType::get( getTypeConverter()->convertType(resTy.getElementType()), 3); @@ -4039,40 +4041,42 @@ struct InsertSliceAsyncOpConversion auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth(); // XXX(Keren): Tune CG and CA here. + auto byteWidth = bitWidth / 8; CacheModifier srcCacheModifier = - bitWidth == 128 ? 
CacheModifier::CG : CacheModifier::CA; - assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32); + byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA; + assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4); + auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8; - for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) { + auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; + for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) { PTXBuilder ptxBuilder; - auto ©AsyncOp = *ptxBuilder.create( - srcCacheModifier, op.evict()); - - auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}]; - auto *dstOperand = - ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset); - auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l"); - auto *copySize = ptxBuilder.newConstantOperand(bitWidth); + auto wordElemIdx = wordIdx * numWordElems; + auto ©AsyncOp = + *ptxBuilder.create(srcCacheModifier); + auto *dstOperand = ptxBuilder.newAddrOperand( + tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth); + auto *srcOperand = + ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l"); + auto *copySize = ptxBuilder.newConstantOperand(byteWidth); auto *srcSize = copySize; if (op.mask()) { // We don't use predicate in this case, setting src-size to 0 // if there's any mask. cp.async will automatically fill the // remaining slots with 0 if cp-size > src-size. // XXX(Keren): Always assume other = 0 for now. - auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems], - i32_val(bitWidth), i32_val(0)); + auto selectOp = select(maskElems[elemIdx + wordElemIdx], + i32_val(byteWidth), i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); } copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); - ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext())); + ptxBuilder.launch(rewriter, loc, void_ty(getContext())); } } PTXBuilder ptxBuilder; ptxBuilder.create()->operator()(); - auto ret = - ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext())); - rewriter.replaceOp(op, ret); + ptxBuilder.launch(rewriter, loc, void_ty(getContext())); + rewriter.replaceOp(op, llDst); return success(); } }; @@ -4172,21 +4176,39 @@ public: int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - // step 1: Convert FuncOp to LLVMFuncOp via partial conversion - // step 2: Allocate for shared memories - // step 3: Convert the rest of ops via partial conversion - // The reason for a seperation between 1/3 is that, step 2 is out of + // step 1: Allocate shared memories and insert barriers + // setp 2: Convert SCF to CFG + // step 3: Convert FuncOp to LLVMFuncOp via partial conversion + // step 4: Convert the rest of ops via partial conversion + // The reason for putting step 1 before step 2 is that the membar analysis + // currently only supports SCF but not CFG. + // The reason for a seperation between 1/4 is that, step 3 is out of // the scope of Dialect Conversion, thus we need to make sure the smem - // is not revised during the conversion of step 3. + // is not revised during the conversion of step 4. 
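Note on the shared-memory handling in the hunks that follow: the allocation size computed in step 1 is recorded in a module-level triton_gpu.shared attribute, and the global_smem array becomes a size-0 external symbol, so the actual storage is dynamic shared memory supplied at launch time. The CUDA sketch below is a stand-in, not Triton's launcher (the kernel, sizes, and launch configuration are made up); it only shows how a size read from such an attribute would be consumed on the host side.

#include <cuda_runtime.h>
#include <cstdio>

// Stand-in for a compiled Triton kernel: the size-0 external global_smem in the
// generated LLVM IR corresponds to an extern __shared__ buffer like this one.
extern __shared__ char global_smem[];

__global__ void triton_kernel_stub(float *out) {
  float *tile = reinterpret_cast<float *>(global_smem);
  tile[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[blockIdx.x * blockDim.x + threadIdx.x] = tile[threadIdx.x];
}

int main() {
  // Hypothetical value; in Triton it would come from the triton_gpu.shared
  // attribute exposed through get_shared_memory_size later in this patch.
  int shared_bytes = 64 * 1024;  // above the 48 KiB static limit

  // Dynamic shared memory beyond 48 KiB must be opted into per kernel
  // (requires a GPU that supports larger shared memory per block).
  cudaFuncSetAttribute(triton_kernel_stub,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared_bytes);

  float *out = nullptr;
  cudaMalloc(&out, 128 * sizeof(float));
  triton_kernel_stub<<<1, 128, shared_bytes>>>(out);
  printf("launch status: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
  cudaFree(out);
  return 0;
}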
+ Allocation allocation(mod); + MembarAnalysis membar(&allocation); + + RewritePatternSet scf_patterns(context); + mlir::populateLoopToStdConversionPatterns(scf_patterns); + mlir::ConversionTarget scf_target(*context); + scf_target.addIllegalOp(); + scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + if (failed( + applyPartialConversion(mod, scf_target, std::move(scf_patterns)))) + return signalPassFailure(); + RewritePatternSet func_patterns(context); func_patterns.add(typeConverter, numWarps, 1 /*benefit*/); if (failed( applyPartialConversion(mod, funcTarget, std::move(func_patterns)))) return signalPassFailure(); - Allocation allocation(mod); auto axisAnalysis = runAxisAnalysis(mod); initSharedMemory(allocation.getSharedMemorySize(), typeConverter); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32), + allocation.getSharedMemorySize())); // We set a higher benefit here to ensure triton's patterns runs before // arith patterns for some encoding not supported by the community @@ -4229,9 +4251,11 @@ void ConvertTritonGPUToLLVM::initSharedMemory( OpBuilder b(mod.getBodyRegion()); auto loc = mod.getLoc(); auto elemTy = typeConverter.convertType(b.getIntegerType(8)); - auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size); + // Set array size 0 and external linkage indicates that we use dynamic shared + // allocation to allow a larger shared memory size for each kernel. + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); auto global = b.create( - loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal, + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, "global_smem", /*value=*/Attribute(), /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace()); SmallVector funcs; diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 74cd31b2a..0f363738f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -103,16 +103,21 @@ public: if (!arg) return mlir::failure(); // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) - // cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2)) auto alloc_tensor = dyn_cast(arg); if (alloc_tensor) { rewriter.replaceOpWithNewOp( op, op->getResult(0).getType()); return mlir::success(); } + // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) auto insert_slice = dyn_cast(arg); if (insert_slice) { auto newType = op->getResult(0).getType(); + // Ensure that the new insert_slice op is placed in the same place as the + // old insert_slice op. Otherwise, the new insert_slice op may be placed + // after the async_wait op, which is not allowed. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insert_slice); auto new_arg = rewriter.create( op->getLoc(), newType, insert_slice.dst()); rewriter.replaceOpWithNewOp( @@ -126,6 +131,11 @@ public: auto extract_slice = dyn_cast(arg); if (extract_slice) { auto origType = extract_slice.src().getType().cast(); + // Ensure that the new extract_slice op is placed in the same place as the + // old extract_slice op. Otherwise, the new extract_slice op may be placed + // after the async_wait op, which is not allowed. 
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(extract_slice); auto newType = RankedTensorType::get( origType.getShape(), origType.getElementType(), op->getResult(0).getType().cast().getEncoding()); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index c22b7d5f4..ccb97aa52 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -78,6 +78,9 @@ public: /// emit pipelined loads (before loop body) void emitPrologue(); + /// emit pipelined loads (after loop body) + void emitEpilogue(); + /// create the new ForOp (add new args & insert prefetched ops) scf::ForOp createNewForOp(); @@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() { loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0); loadsExtract[loadOp] = extractSlice; } + // bump up loopIterIdx, this is used for getting the correct slice for the + // *next* iteration + loopIterIdx = builder.create( + loopIterIdx.getLoc(), loopIterIdx, + builder.create(loopIterIdx.getLoc(), 1, 32)); +} + +void LoopPipeliner::emitEpilogue() { + // If there's any outstanding async copies, we need to wait for them. + // TODO(Keren): We may want to completely avoid the async copies in the last + // few iterations by setting is_masked attribute to true. We don't want to use + // the mask operand because it's a tensor but not a scalar. + OpBuilder builder(forOp); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(forOp); + Operation *asyncWait = + builder.create(forOp.getLoc(), 0); } scf::ForOp LoopPipeliner::createNewForOp() { @@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase { scf::ForOp newForOp = pipeliner.createNewForOp(); + pipeliner.emitEpilogue(); + // replace the original loop for (unsigned i = 0; i < forOp->getNumResults(); ++i) forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 5967a0e04..5ed79cd81 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -136,10 +136,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, /*printAfterOnlyOnChange=*/true, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); - pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(createConvertTritonGPUToLLVMPass()); // Conanicalize to eliminate the remaining UnrealizedConversionCastOp pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability. 
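Note on the Pipeline.cpp epilogue above: the async_wait with num = 0 lowers to cp.async.wait_group 0, which blocks until every outstanding copy group has completed, so no cp.async issued inside the loop is still in flight once the loop exits. The CUDA sketch below is a hand-written analogue of the structure the pipeliner produces, with buffer indexing and the actual dot product elided and all names hypothetical.

// Minimal sketch of the commit/wait pattern in a double-buffered pipelined
// loop; only the group bookkeeping is shown.
__device__ void wait_all_async_copies() {
  // Analogue of triton_gpu.async_wait {num = 0}: wait until all outstanding
  // cp.async groups are done. Emitted after the loop by emitEpilogue so that
  // no dangling copies outlive the loop.
  asm volatile("cp.async.wait_group 0;");
  __syncthreads();
}

__device__ void pipelined_loop(int num_iters) {
  for (int i = 0; i < num_iters; ++i) {
    // ... issue cp.async copies that prefetch a future iteration ...
    asm volatile("cp.async.commit_group;");
    // Steady state: allow at most two groups to remain in flight, the same
    // num = 2 bound that appears in test/TritonGPU/loop-pipeline.mlir.
    asm volatile("cp.async.wait_group 2;");
    __syncthreads();
    // ... consume the tiles prefetched for iteration i (tt.dot) ...
  }
  wait_all_async_copies();  // the new epilogue wait
}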
+ pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createCanonicalizerPass()); if (failed(pm.run(module))) { llvm::errs() << "Pass execution failed"; diff --git a/python/src/triton.cc b/python/src/triton.cc index fb893219c..a15d2dda2 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1257,8 +1257,8 @@ void init_triton_translation(py::module &m) { using ret = py::return_value_policy; m.def("get_shared_memory_size", [](mlir::ModuleOp module) { - auto pass = std::make_unique(module); - return pass->getSharedMemorySize(); + return module->getAttrOfType("triton_gpu.shared") + .getInt(); }); m.def( diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 2ef5293e4..f711fde24 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -875,7 +875,7 @@ def optimize_tritongpu_ir(mod, num_stages): pm.enable_debug() # Get error in backend due to wrong conversion in expanding async-related instruction. # TODO[Superjomn]: Open it when fixed. - # pm.add_tritongpu_pipeline_pass(num_stages) + pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() pm.add_coalesce_pass() diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 4652f6c2b..234277bd7 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -326,7 +326,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - // CHECK: llvm.mlir.global internal @global_smem + // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor func @basic_alloc_tensor() { // CHECK: llvm.mlir.addressof @global_smem @@ -343,7 +343,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - // CHECK: llvm.mlir.global internal @global_smem + // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_extract_slice func @basic_extract_slice() { // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem @@ -382,10 +382,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}> #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}> #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 - func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 8 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> @@ -404,9 +404,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80 + // CHECK-SAME: 
cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80 + // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr, #AL> -> tensor<2x16x64xf32, #A> @@ -445,13 +445,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr, #AL> -> tensor<2x16x32xf32, #A> @@ -489,21 +489,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm - // CHECK-SAME: 
cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20 + // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr, #AL> -> tensor<2x32x32xf32, #A> @@ -545,7 +545,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8> + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem @@ -593,7 +593,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8> + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_vec func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem @@ -617,7 +617,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8> + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem @@ -682,7 +682,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8> + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_mma_block func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) { // CHECK: nvvm.barrier0 @@ -703,7 +703,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 8, perPhase 
= 2, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { - // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8> + // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_shared func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: llvm.store diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index ec0a06aee..a1d333cb6 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -22,7 +22,7 @@ // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] @@ -78,7 +78,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]] // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]] -// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] @@ -131,7 +131,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr