[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite the code generation for insert_slice_async.
2. Correct the wrong index passed to extract_slice in the pipeline pass.
3. Add an epilogue to the pipeline pass that waits for dangling cp.asyncs.
4. Move the SCF-to-CF conversion inside TritonGPUToLLVM, because membar has to run before SCF is lowered to CF. This is not a fundamental limitation and could be lifted by a more general membar analysis.
5. Use a module attribute to memoize the shared memory size and support dynamic shared memory.
6. Prevent the combine pass from reordering insert_slice and extract_slice across async_wait.

Co-authored-by: Superjomn <yanchunwei@outlook.com>
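Point 5 works by stashing the allocation result on the module instead of recomputing it. The sketch below is not part of the diff; it shows the memoization pattern in isolation, assuming MLIR headers are available. The two helper function names are illustrative; the attribute name triton_gpu.shared is the one used by the patch.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

// Record the shared memory size computed by the allocation analysis as a
// module attribute, so later stages can read it back without re-running the
// analysis (this is what the get_shared_memory_size binding now does).
void memoizeSharedMemorySize(mlir::ModuleOp mod, int64_t sharedMemorySize) {
  mlir::MLIRContext *ctx = mod.getContext();
  mod->setAttr("triton_gpu.shared",
               mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
                                      sharedMemorySize));
}

// Read the memoized value back; returns 0 if the conversion pass has not
// attached the attribute yet.
int64_t readSharedMemorySize(mlir::ModuleOp mod) {
  if (auto attr = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared"))
    return attr.getInt();
  return 0;
}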
@@ -292,13 +292,11 @@ struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
 
 struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
   explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
-                               triton::CacheModifier modifier,
-                               triton::EvictionPolicy policy)
+                               triton::CacheModifier modifier)
       : PTXCpAsyncInstrBase(builder) {
     o(triton::stringifyCacheModifier(modifier).str());
     o("shared");
     o("global");
-    o("L2::" + triton::stringifyEvictionPolicy(policy).str());
   }
 };
 
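For reference, a minimal sketch (not part of the patch) of a call site under the simplified constructor, now that the eviction-policy operand is gone. It mirrors the InsertSliceAsyncOpConversion code further down and assumes the PtxAsmFormat.h declarations and the same namespaces/using-directives as that file; the function name and parameters are placeholders. For a 16-byte word the emitted instruction has the form cp.async.cg.shared.global [dst], [src], 16, 16;, which matches the updated lit tests below.

// Emit one 16-byte cp.async using the new single-argument constructor.
void emitCpAsyncCg16(mlir::ConversionPatternRewriter &rewriter,
                     mlir::Location loc, mlir::MLIRContext *ctx,
                     mlir::Value dstSharedPtr, mlir::Value srcGlobalPtr) {
  PTXBuilder ptxBuilder;
  auto &copyAsyncOp =
      *ptxBuilder.create<PTXCpAsyncLoadInstr>(CacheModifier::CG);
  auto *dstOperand = ptxBuilder.newAddrOperand(dstSharedPtr, "r");
  auto *srcOperand = ptxBuilder.newAddrOperand(srcGlobalPtr, "l");
  auto *copySize = ptxBuilder.newConstantOperand(16); // cp-size in bytes
  // src-size == cp-size here: read all 16 bytes (the masked case below
  // instead shrinks src-size to 0 so the slot is zero-filled).
  copyAsyncOp(dstOperand, srcOperand, copySize, copySize);
  ptxBuilder.launch(rewriter, loc, mlir::LLVM::LLVMVoidType::get(ctx));
}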
@@ -42,18 +42,14 @@ void MembarAnalysis::dfsOperation(Operation *operation,
 
 void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
                               OpBuilder *builder) {
-  if (op->getNumResults() < 1)
-    return;
-
   if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
       isa<triton::gpu::ExtractSliceOp>(op) ||
-      isa<triton::gpu::InsertSliceAsyncOp>(op) ||
       isa<triton::gpu::AllocTensorOp>(op)) {
     // Do not insert barriers before control flow operations and
     // alloc/extract/insert
     // alloc is an allocation op without memory write.
     // In contrast, arith.constant is an allocation op with memory write.
-    // FIXME(Keren): extract and insert are always alias for now
+    // FIXME(Keren): extract is always alias for now
     return;
   }
 
@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
@@ -14,6 +15,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Membar.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
@@ -108,7 +110,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define i32_ty rewriter.getIntegerType(32)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
-#define void_ty LLVM::LLVMVoidType::get(ctx)
+#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
 
 // Creator for constant
@@ -360,7 +362,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
   auto *ptrOpr = builder.newAddrOperand(ptr, "r");
   auto *valOpr = builder.newOperand(val, c);
   st(ptrOpr, valOpr).predicate(pred, "b");
-  return builder.launch(rewriter, loc, void_ty);
+  return builder.launch(rewriter, loc, void_ty(ctx));
 }
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
@@ -1151,7 +1153,7 @@ struct StoreOpConversion
       llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
       argTys.insert(argTys.end(), nWords, valArgTy);
 
-      auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
+      auto ASMReturnTy = void_ty(ctx);
 
       ptxBuilder.launch(rewriter, loc, ASMReturnTy);
     }
@@ -3869,7 +3871,7 @@ struct AsyncWaitOpConversion
 
     auto ctx = op.getContext();
     auto loc = op.getLoc();
-    auto voidTy = LLVM::LLVMVoidType::get(ctx);
+    auto voidTy = void_ty(ctx);
    auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
 
     // Safe to remove the op since it doesn't have any return value.
@@ -3907,7 +3909,7 @@ struct InsertSliceAsyncOpConversion
 
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto resTy = dst.getType().cast<RankedTensorType>();
-    auto resElemTy = resTy.getElementType();
+    auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
     auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
     auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
     auto srcShape = srcTy.getShape();
@@ -3928,7 +3930,7 @@ struct InsertSliceAsyncOpConversion
     assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
     auto dstBase = createIndexAttrConstant(rewriter, loc,
                                            getTypeConverter()->getIndexType(),
-                                           product<int64_t>(resTy.getShape()));
+                                           product<int64_t>(srcTy.getShape()));
     Value offset = mul(llIndex, dstBase);
     auto dstPtrTy = LLVM::LLVMPointerType::get(
         getTypeConverter()->convertType(resTy.getElementType()), 3);
@@ -4039,40 +4041,42 @@ struct InsertSliceAsyncOpConversion
       auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
 
       // XXX(Keren): Tune CG and CA here.
+      auto byteWidth = bitWidth / 8;
       CacheModifier srcCacheModifier =
-          bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
-      assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
-
-      for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
-        PTXBuilder ptxBuilder;
-        auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
-            srcCacheModifier, op.evict());
+          byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
+      assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
+      auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
 
       auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-        auto *dstOperand =
-            ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
-        auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
-        auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
+      for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+        PTXBuilder ptxBuilder;
+        auto wordElemIdx = wordIdx * numWordElems;
+        auto &copyAsyncOp =
+            *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+        auto *dstOperand = ptxBuilder.newAddrOperand(
+            tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
+        auto *srcOperand =
+            ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
+        auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
         auto *srcSize = copySize;
         if (op.mask()) {
           // We don't use predicate in this case, setting src-size to 0
           // if there's any mask. cp.async will automatically fill the
           // remaining slots with 0 if cp-size > src-size.
           // XXX(Keren): Always assume other = 0 for now.
-          auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
-                                 i32_val(bitWidth), i32_val(0));
+          auto selectOp = select(maskElems[elemIdx + wordElemIdx],
+                                 i32_val(byteWidth), i32_val(0));
           srcSize = ptxBuilder.newOperand(selectOp, "r");
         }
         copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
-        ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
+        ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
       }
     }
 
    PTXBuilder ptxBuilder;
    ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
-    auto ret =
-        ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
-    rewriter.replaceOp(op, ret);
+    ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+    rewriter.replaceOp(op, llDst);
    return success();
  }
 };
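The cp-size/src-size selection above is the heart of the rewrite. Below is a standalone restatement with plain integers, not part of the patch; the struct and function names are illustrative. The zero-fill behavior is the one quoted in the comment above (cp.async pads shared memory with zeros when cp-size exceeds src-size), and the restriction of the .cg variant to 16-byte copies follows the PTX ISA.

#include <cassert>
#include <cstdint>

// Plain-integer illustration of the logic wired into the PTX builder above.
struct CpAsyncRequest {
  bool useCacheGlobal; // cp.async.cg (16-byte only) vs cp.async.ca
  uint32_t cpSize;     // bytes written to shared memory
  uint32_t srcSize;    // bytes actually read from global memory
};

CpAsyncRequest makeCpAsyncRequest(unsigned bitWidth, bool maskedOff) {
  unsigned byteWidth = bitWidth / 8;
  assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
  CpAsyncRequest req;
  req.useCacheGlobal = (byteWidth == 16); // only the 16-byte form may use .cg
  req.cpSize = byteWidth;
  // Masked-off lane: keep cp-size but drop src-size to 0, so the destination
  // slot is zero-filled instead of being guarded by a predicate.
  req.srcSize = maskedOff ? 0 : byteWidth;
  return req;
}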
@@ -4172,21 +4176,39 @@ public:
 
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
 
-    // step 1: Convert FuncOp to LLVMFuncOp via partial conversion
-    // step 2: Allocate for shared memories
-    // step 3: Convert the rest of ops via partial conversion
-    // The reason for a seperation between 1/3 is that, step 2 is out of
+    // step 1: Allocate shared memories and insert barriers
+    // setp 2: Convert SCF to CFG
+    // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
+    // step 4: Convert the rest of ops via partial conversion
+    // The reason for putting step 1 before step 2 is that the membar analysis
+    // currently only supports SCF but not CFG.
+    // The reason for a seperation between 1/4 is that, step 3 is out of
     // the scope of Dialect Conversion, thus we need to make sure the smem
-    // is not revised during the conversion of step 3.
+    // is not revised during the conversion of step 4.
+    Allocation allocation(mod);
+    MembarAnalysis membar(&allocation);
+
+    RewritePatternSet scf_patterns(context);
+    mlir::populateLoopToStdConversionPatterns(scf_patterns);
+    mlir::ConversionTarget scf_target(*context);
+    scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
+                            scf::WhileOp, scf::ExecuteRegionOp>();
+    scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+    if (failed(
+            applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
+      return signalPassFailure();
+
     RewritePatternSet func_patterns(context);
     func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
     if (failed(
             applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
       return signalPassFailure();
 
-    Allocation allocation(mod);
     auto axisAnalysis = runAxisAnalysis(mod);
     initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
+    mod->setAttr("triton_gpu.shared",
+                 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
+                                        allocation.getSharedMemorySize()));
 
     // We set a higher benefit here to ensure triton's patterns runs before
     // arith patterns for some encoding not supported by the community
@@ -4229,9 +4251,11 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
+  // Set array size 0 and external linkage indicates that we use dynamic shared
+  // allocation to allow a larger shared memory size for each kernel.
+  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
+      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
       "global_smem", /*value=*/Attribute(),
       /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
   SmallVector<LLVM::LLVMFuncOp> funcs;
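The external, zero-length @global_smem array above models dynamically sized shared memory, so the actual byte count has to be supplied at kernel launch. The Triton launcher itself is not part of this diff; the sketch below only shows the generic CUDA driver API mechanism, with kernel, grid/block sizes, and sharedMemBytes (for example, the value reported by get_shared_memory_size) as placeholders.

#include <cuda.h>

// Launch a kernel that declares its shared memory dynamically.
CUresult launchWithDynamicSmem(CUfunction kernel, unsigned gridX,
                               unsigned blockX, unsigned sharedMemBytes,
                               CUstream stream, void **params) {
  // Opt in to more than the default 48KB of dynamic shared memory if needed.
  if (sharedMemBytes > 48 * 1024) {
    CUresult err = cuFuncSetAttribute(
        kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        sharedMemBytes);
    if (err != CUDA_SUCCESS)
      return err;
  }
  // The size passed here backs the zero-length, externally linked
  // @global_smem array declared by initSharedMemory above.
  return cuLaunchKernel(kernel, gridX, 1, 1, blockX, 1, 1, sharedMemBytes,
                        stream, params, /*extra=*/nullptr);
}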
@@ -103,16 +103,21 @@ public:
     if (!arg)
       return mlir::failure();
     // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
-    // cvt(insert_slice(x), type2) -> extract_slice(cvt(x, type2))
     auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
     if (alloc_tensor) {
       rewriter.replaceOpWithNewOp<triton::gpu::AllocTensorOp>(
           op, op->getResult(0).getType());
       return mlir::success();
     }
+    // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
     auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
     if (insert_slice) {
       auto newType = op->getResult(0).getType();
+      // Ensure that the new insert_slice op is placed in the same place as the
+      // old insert_slice op. Otherwise, the new insert_slice op may be placed
+      // after the async_wait op, which is not allowed.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(insert_slice);
       auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
           op->getLoc(), newType, insert_slice.dst());
       rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
@@ -126,6 +131,11 @@ public:
     auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
     if (extract_slice) {
       auto origType = extract_slice.src().getType().cast<RankedTensorType>();
+      // Ensure that the new extract_slice op is placed in the same place as the
+      // old extract_slice op. Otherwise, the new extract_slice op may be placed
+      // after the async_wait op, which is not allowed.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(extract_slice);
       auto newType = RankedTensorType::get(
           origType.getShape(), origType.getElementType(),
           op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
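Both convert-layout patterns above rely on the same trick, distilled here as a sketch that is not part of the patch (rewriter and extractSlice are placeholders): the replacement op must be created at the position of the op it replaces, because the rewriter's current insertion point sits at the convert_layout, which may already be on the far side of a triton_gpu.async_wait.

#include "mlir/IR/PatternMatch.h"

void rewriteAtOriginalPosition(mlir::PatternRewriter &rewriter,
                               mlir::Operation *extractSlice) {
  // Save the current insertion point (at the convert_layout) and restore it
  // automatically when the guard goes out of scope.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  // Create the replacement right before the original op, i.e. still ahead of
  // any later async_wait.
  rewriter.setInsertionPoint(extractSlice);
  // ... build the new extract_slice / insert_slice_async here ...
}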
@@ -78,6 +78,9 @@ public:
   /// emit pipelined loads (before loop body)
   void emitPrologue();
 
+  /// emit pipelined loads (after loop body)
+  void emitEpilogue();
+
   /// create the new ForOp (add new args & insert prefetched ops)
   scf::ForOp createNewForOp();
 
@@ -362,6 +365,23 @@ void LoopPipeliner::emitPrologue() {
         loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
     loadsExtract[loadOp] = extractSlice;
   }
+  // bump up loopIterIdx, this is used for getting the correct slice for the
+  // *next* iteration
+  loopIterIdx = builder.create<arith::AddIOp>(
+      loopIterIdx.getLoc(), loopIterIdx,
+      builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
 }
+
+void LoopPipeliner::emitEpilogue() {
+  // If there's any outstanding async copies, we need to wait for them.
+  // TODO(Keren): We may want to completely avoid the async copies in the last
+  // few iterations by setting is_masked attribute to true. We don't want to use
+  // the mask operand because it's a tensor but not a scalar.
+  OpBuilder builder(forOp);
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPointAfter(forOp);
+  Operation *asyncWait =
+      builder.create<triton::gpu::AsyncWaitOp>(forOp.getLoc(), 0);
+}
 
 scf::ForOp LoopPipeliner::createNewForOp() {
@@ -581,6 +601,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
 
     scf::ForOp newForOp = pipeliner.createNewForOp();
 
+    pipeliner.emitEpilogue();
+
     // replace the original loop
     for (unsigned i = 0; i < forOp->getNumResults(); ++i)
       forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i));
@@ -136,10 +136,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
       /*printAfterOnlyOnChange=*/true,
       /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
 
-  pm.addPass(mlir::createLowerToCFGPass());
   pm.addPass(createConvertTritonGPUToLLVMPass());
+  // Conanicalize to eliminate the remaining UnrealizedConversionCastOp
+  pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
   pm.addPass(mlir::createSymbolDCEPass());
+  pm.addPass(mlir::createCanonicalizerPass());
 
   if (failed(pm.run(module))) {
     llvm::errs() << "Pass execution failed";
@@ -1257,8 +1257,8 @@ void init_triton_translation(py::module &m) {
   using ret = py::return_value_policy;
 
   m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
-    auto pass = std::make_unique<mlir::Allocation>(module);
-    return pass->getSharedMemorySize();
+    return module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
+        .getInt();
   });
 
   m.def(
@@ -875,7 +875,7 @@ def optimize_tritongpu_ir(mod, num_stages):
     pm.enable_debug()
     # Get error in backend due to wrong conversion in expanding async-related instruction.
     # TODO[Superjomn]: Open it when fixed.
-    # pm.add_tritongpu_pipeline_pass(num_stages)
+    pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
     pm.add_coalesce_pass()
@@ -326,7 +326,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK: llvm.mlir.global external @global_smem
   // CHECK-LABEL: basic_alloc_tensor
   func @basic_alloc_tensor() {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -343,7 +343,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK: llvm.mlir.global external @global_smem
   // CHECK-LABEL: basic_extract_slice
   func @basic_extract_slice() {
     // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
@@ -382,10 +382,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
 #slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_insert_slice_async_v4
-  func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+  func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}) {
     %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
     %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
     %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -404,9 +404,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
 
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-    // CHECK-SAME: cp.async.cg.shared.global.L2::evict_normal [ ${{.*}} + 8 ], [ ${{.*}} + 0 ], 0x80, 0x80
+    // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
     // CHECK-SAME: cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64x!tt.ptr<f32>, #AL> -> tensor<2x16x64xf32, #A>
@@ -445,13 +445,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
 
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
     // CHECK-SAME: cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x32x!tt.ptr<f32>, #AL> -> tensor<2x16x32xf32, #A>
@@ -489,21 +489,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %index = arith.constant 1 : i32
 
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: cp.async.ca.shared.global.L2::evict_normal [ ${{.*}} + 512 ], [ ${{.*}} + 0 ], 0x20, 0x20
+    // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 2048 ], [ ${{.*}} + 0 ], 0x4, 0x4
     // CHECK: llvm.inline_asm
     // CHECK-SAME: cp.async.commit_group
     %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32x!tt.ptr<f32>, #AL> -> tensor<2x32x32xf32, #A>
@@ -545,7 +545,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked
   func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -593,7 +593,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked_vec
   func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -617,7 +617,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
   func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
@@ -682,7 +682,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
 #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
   // CHECK-LABEL: convert_layout_mma_block
   func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
     // CHECK: nvvm.barrier0
@@ -703,7 +703,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
-  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
+  // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
  // CHECK-LABEL: convert_layout_blocked_shared
   func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
     // CHECK: llvm.store
@@ -22,7 +22,7 @@
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
@@ -78,7 +78,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
 // CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]], %[[CONSTANT_0]]
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
@@ -131,7 +131,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
 // CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
 // CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]], %[[CONSTANT_0]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: tt.dot {{.*}}, %[[arg_b0]], {{.*}}
 // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
 // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]