[Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)

1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in the pipeline pass.
3. Add a prologue in the pipeline pass to wait for dangling cp.asyncs.
4. Move the SCF-to-CF conversion inside TritonGPUToLLVM because the membar
analysis must run before SCF is lowered to CF. This is not a fundamental
limitation and could be lifted by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass from reordering insert_slice and extract_slice
across async_wait (see the cp.async ordering sketch below).
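
Items 3 and 6 both protect the same hardware contract: data written by
cp.async only becomes visible after its group is committed and waited on, so
no access to the destination buffer may be scheduled across the wait. A
minimal CUDA sketch of that contract (sm_80+; the kernel and buffer names are
illustrative, not part of this patch):

#include <cuda_runtime.h>

// Data written by cp.async is only safe to read after cp.async.wait_group;
// hoisting the read above the wait (the reordering that item 6 forbids at
// the IR level) would observe stale shared memory.
__global__ void cpAsyncContract(const float *src, float *dst) {
  __shared__ float buf[128];
  unsigned smemAddr =
      static_cast<unsigned>(__cvta_generic_to_shared(&buf[threadIdx.x]));
  asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n" ::"r"(smemAddr),
               "l"(src + threadIdx.x));
  asm volatile("cp.async.commit_group;\n");
  // The kind of wait that item 3's prologue inserts for outstanding copies:
  asm volatile("cp.async.wait_group 0;\n" ::: "memory");
  __syncthreads();
  dst[threadIdx.x] = buf[threadIdx.x]; // only now is this read safe
}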

Co-authored-by: Superjomn <yanchunwei@outlook.com>
Author: Keren Zhou
Date: 2022-10-27 22:09:06 -07:00
Committed by: GitHub
Parent: 42db3538e4
Commit: 3b80801dff
10 changed files with 122 additions and 70 deletions


@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
@@ -14,6 +15,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Membar.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
@@ -108,7 +110,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define i32_ty rewriter.getIntegerType(32)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
-#define void_ty LLVM::LLVMVoidType::get(ctx)
+#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
 // Creator for constant
@@ -360,7 +362,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
   auto *ptrOpr = builder.newAddrOperand(ptr, "r");
   auto *valOpr = builder.newOperand(val, c);
   st(ptrOpr, valOpr).predicate(pred, "b");
-  return builder.launch(rewriter, loc, void_ty);
+  return builder.launch(rewriter, loc, void_ty(ctx));
 }
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
@@ -1151,7 +1153,7 @@ struct StoreOpConversion
     llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
     argTys.insert(argTys.end(), nWords, valArgTy);
-    auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx);
+    auto ASMReturnTy = void_ty(ctx);
     ptxBuilder.launch(rewriter, loc, ASMReturnTy);
   }
@@ -3869,7 +3871,7 @@ struct AsyncWaitOpConversion
     auto ctx = op.getContext();
     auto loc = op.getLoc();
-    auto voidTy = LLVM::LLVMVoidType::get(ctx);
+    auto voidTy = void_ty(ctx);
     auto ret = ptxBuilder.launch(rewriter, loc, voidTy);
     // Safe to remove the op since it doesn't have any return value.
@@ -3907,7 +3909,7 @@ struct InsertSliceAsyncOpConversion
     auto srcTy = src.getType().cast<RankedTensorType>();
     auto resTy = dst.getType().cast<RankedTensorType>();
-    auto resElemTy = resTy.getElementType();
+    auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
     auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
     auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
     auto srcShape = srcTy.getShape();
@@ -3928,7 +3930,7 @@ struct InsertSliceAsyncOpConversion
     assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
     auto dstBase = createIndexAttrConstant(rewriter, loc,
                                            getTypeConverter()->getIndexType(),
-                                           product<int64_t>(resTy.getShape()));
+                                           product<int64_t>(srcTy.getShape()));
     Value offset = mul(llIndex, dstBase);
     auto dstPtrTy = LLVM::LLVMPointerType::get(
         getTypeConverter()->convertType(resTy.getElementType()), 3);
@@ -4039,40 +4041,42 @@ struct InsertSliceAsyncOpConversion
        auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
        // XXX(Keren): Tune CG and CA here.
+       auto byteWidth = bitWidth / 8;
        CacheModifier srcCacheModifier =
-           bitWidth == 128 ? CacheModifier::CG : CacheModifier::CA;
-       assert(bitWidth == 128 || bitWidth == 64 || bitWidth == 32);
+           byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
+       assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
+       auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
 
-       for (int wordIdx = 0; wordIdx < numWords; ++wordIdx) {
+       auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
+       for (unsigned wordIdx = 0; wordIdx < numWords; ++wordIdx) {
          PTXBuilder ptxBuilder;
-         auto &copyAsyncOp = *ptxBuilder.create<PTXCpAsyncLoadInstr>(
-             srcCacheModifier, op.evict());
-         auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
-         auto *dstOperand =
-             ptxBuilder.newAddrOperand(tileOffset, "r", baseOffset);
-         auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[vecIdx], "l");
-         auto *copySize = ptxBuilder.newConstantOperand(bitWidth);
+         auto wordElemIdx = wordIdx * numWordElems;
+         auto &copyAsyncOp =
+             *ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
+         auto *dstOperand = ptxBuilder.newAddrOperand(
+             tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
+         auto *srcOperand =
+             ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
+         auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
          auto *srcSize = copySize;
          if (op.mask()) {
            // We don't use predicate in this case, setting src-size to 0
            // if there's any mask. cp.async will automatically fill the
            // remaining slots with 0 if cp-size > src-size.
            // XXX(Keren): Always assume other = 0 for now.
-           auto selectOp = select(maskElems[vecIdx + wordIdx * numWordElems],
-                                  i32_val(bitWidth), i32_val(0));
+           auto selectOp = select(maskElems[elemIdx + wordElemIdx],
+                                  i32_val(byteWidth), i32_val(0));
            srcSize = ptxBuilder.newOperand(selectOp, "r");
          }
          copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
-         ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
+         ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
        }
      }
 
      PTXBuilder ptxBuilder;
      ptxBuilder.create<PTXCpAsyncCommitGroupInstr>()->operator()();
-     auto ret =
-         ptxBuilder.launch(rewriter, loc, LLVM::LLVMVoidType::get(getContext()));
-     rewriter.replaceOp(op, ret);
+     ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
+     rewriter.replaceOp(op, llDst);
      return success();
    }
  };
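
A note on the hunk above: the select between byteWidth and 0 works because
PTX cp.async accepts an optional src-size operand and zero-fills the
remaining cp-size minus src-size bytes of the destination, so a masked-off
element becomes a zero store rather than a predicated load. This also
explains the CacheModifier choice: cp.async.cg only supports 16-byte copies,
hence CG exactly when byteWidth == 16. A hedged CUDA illustration (names are
illustrative, not from this patch):

// sm_80+: masked cp.async via src-size = 0; masked-off lanes zero-fill
// their shared-memory slot instead of touching global memory.
__device__ void maskedCpAsync(float4 *smemSlot, const float4 *gmemSrc,
                              bool mask) {
  unsigned dst = static_cast<unsigned>(__cvta_generic_to_shared(smemSlot));
  int srcSize = mask ? 16 : 0; // mirrors select(mask, byteWidth, 0)
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(dst),
               "l"(gmemSrc), "r"(srcSize));
}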
@@ -4172,21 +4176,39 @@ public:
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-   // step 1: Convert FuncOp to LLVMFuncOp via partial conversion
-   // step 2: Allocate for shared memories
-   // step 3: Convert the rest of ops via partial conversion
-   // The reason for a separation between 1/3 is that step 2 is out of
+   // step 1: Allocate shared memories and insert barriers
+   // step 2: Convert SCF to CFG
+   // step 3: Convert FuncOp to LLVMFuncOp via partial conversion
+   // step 4: Convert the rest of ops via partial conversion
+   // The reason for putting step 1 before step 2 is that the membar analysis
+   // currently only supports SCF but not CFG.
+   // The reason for a separation between 1/4 is that step 3 is out of
    // the scope of Dialect Conversion, thus we need to make sure the smem
-   // is not revised during the conversion of step 3.
+   // is not revised during the conversion of step 4.
+   Allocation allocation(mod);
+   MembarAnalysis membar(&allocation);
+
+   RewritePatternSet scf_patterns(context);
+   mlir::populateLoopToStdConversionPatterns(scf_patterns);
+   mlir::ConversionTarget scf_target(*context);
+   scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
+                           scf::WhileOp, scf::ExecuteRegionOp>();
+   scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+   if (failed(
+           applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
+     return signalPassFailure();
+
    RewritePatternSet func_patterns(context);
    func_patterns.add<FuncOpConversion>(typeConverter, numWarps, 1 /*benefit*/);
    if (failed(
            applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
      return signalPassFailure();
 
-   Allocation allocation(mod);
    auto axisAnalysis = runAxisAnalysis(mod);
    initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
+   mod->setAttr("triton_gpu.shared",
+                mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
+                                       allocation.getSharedMemorySize()));
 
    // We set a higher benefit here to ensure triton's patterns run before
    // arith patterns for some encoding not supported by the community
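
With the size memoized in triton_gpu.shared, a consumer of the lowered module
(e.g., the code that launches the kernel) can read it back instead of
re-running allocation. A minimal sketch of such a read; only the attribute
name comes from this patch, the helper itself is hypothetical:

// Recover the shared memory requirement memoized by the pass.
int64_t getSharedMemorySize(mlir::ModuleOp mod) {
  auto attr = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
  assert(attr && "module has not been lowered by ConvertTritonGPUToLLVM");
  return attr.getInt();
}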
@@ -4229,9 +4251,11 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
   OpBuilder b(mod.getBodyRegion());
   auto loc = mod.getLoc();
   auto elemTy = typeConverter.convertType(b.getIntegerType(8));
-  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size);
+  // Array size 0 and external linkage indicate that we use dynamic shared
+  // memory allocation to allow a larger shared memory size for each kernel.
+  auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
   auto global = b.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal,
+      loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
      "global_smem", /*value=*/Attribute(),
      /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
  SmallVector<LLVM::LLVMFuncOp> funcs;
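
The size-0 array with external linkage is the IR-level counterpart of CUDA's
extern __shared__ declaration: the actual size is supplied per launch, which
also allows opting in to more than the default 48 KB. A hedged host-side
sketch (the kernel and variable names are illustrative, not from this patch):

#include <cuda_runtime.h>

extern __global__ void tritonKernel(); // hypothetical generated kernel

void launch(dim3 grid, dim3 block, int smemBytes) {
  // Required when the memoized size exceeds the 48 KB default limit.
  cudaFuncSetAttribute(tritonKernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
  tritonKernel<<<grid, block, smemBytes>>>();
}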