[BACKEND] Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp, and fix bugs in SplatOp, StoreOp, FuncOp (#60)

Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp, and GEPOp, and fix bugs in SplatOp, StoreOp, and FuncOp.

Co-authored-by: gzhu <gzhu@nvidia.com>
Author: goostavz
Date: 2022-08-18 20:46:45 +08:00
Committed by: GitHub
Parent: b1673caaf6
Commit: fc58250a06
9 changed files with 270 additions and 122 deletions


@@ -61,7 +61,7 @@ std::string PtxInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return llvm::formatv("%{0}", idx);
return llvm::formatv("${0}", idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
oprs.push_back(opr->dump());
@@ -72,7 +72,7 @@ PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
return llvm::formatv("[ %{0} + {1} ]", idx, off);
return llvm::formatv("[ ${0} + {1} ]", idx, off);
};
return opr;


@@ -46,20 +46,10 @@ template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// The following code is borrowed from the MLIR project, including the following
// functions or classes:
// - filterFuncAttributes
// - ConvertOpToLLVMPattern
// - FuncOpConversion
//
// The code is hidden in the CPP files of the MLIR repo, and we can't call it
// directly. I found such code snippets were refactored and added to LLVMCommon
// in the latest MLIR code, but the v14.0.0 version currently used in Triton
// doesn't contain them.
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in header files in mlir v14
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
//
// The original code:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
// All rights are reserved by the LLVM community.
/// Only retain those attributes that are not constructed by
@@ -79,6 +69,12 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
}
}
/// Helper function for wrapping all attributes into a single DictionaryAttr
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
return DictionaryAttr::get(
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
}
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
@@ -90,25 +86,34 @@ protected:
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
if (!llvmType)
return nullptr;
// Propagate argument attributes to all converted arguments obtained after
// converting a given original argument.
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
(funcOp.getNumResults() == 1)
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
attributes.push_back(rewriter.getNamedAttr(
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto mapping = result.getInputMapping(i);
assert(mapping.hasValue() &&
"unexpected deletion of function argument");
assert(mapping && "unexpected deletion of function argument");
for (size_t j = 0; j < mapping->size; ++j)
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
}
@@ -136,37 +141,15 @@ protected:
}
linkage = attr.getLinkage();
}
auto oldArgs = funcOp.getArguments();
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
// Convert argument
llvm::DenseMap<Value, Value> argMap;
for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
Value oldArg = oldArgs[i];
Value newArg = newFuncOp.getArgument(i);
argMap.try_emplace(oldArg, newArg);
}
newFuncOp.getBody().walk([&](Operation *op) {
// Convert the function argument types, e.g., from !tt.ptr<fp16> to
// ptr<fp16>
for (int i = 0; i < op->getNumOperands(); i++) {
auto arg = op->getOperand(i);
auto it = argMap.find(arg);
if (it != argMap.end())
op->setOperand(i, it->second);
}
});
return newFuncOp;
}
};
@@ -245,8 +228,13 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
return product(shape) / (product(layout.getThreadsPerWarp()) *
product(layout.getWarpsPerCTA()));
size_t rank = shape.size();
SmallVector<unsigned> elemsPerThreadPerDim(rank);
for (size_t i = 0; i < rank; ++i) {
unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
}
return product<unsigned>(elemsPerThreadPerDim);
}
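The hunk above replaces the single global division with a per-dimension ceiling division, so the per-thread element count is no longer undercounted when a tensor dimension is smaller than the thread tiling along that dimension. A standalone sketch of the arithmetic, with made-up shape and layout values (not Triton code):

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 2-D tensor shape and blocked layout.
  std::vector<unsigned> shape = {4, 32};
  std::vector<unsigned> threadsPerWarp = {8, 4};
  std::vector<unsigned> warpsPerCTA = {1, 2};

  // Old formula: product(shape) / (product(threadsPerWarp) * product(warpsPerCTA))
  //            = 128 / (32 * 2) = 2, which undercounts dimension 0 (4 < 8).
  // New formula: per-dimension ceil-division, multiplied together.
  unsigned elems = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    unsigned t = threadsPerWarp[i] * warpsPerCTA[i]; // threads covering dim i
    elems *= (shape[i] + t - 1) / t;                 // ceil(shape[i] / t), never 0
  }
  std::printf("elems per thread = %u\n", elems);     // (4+7)/8 * (32+7)/8 = 1 * 4 = 4
  return 0;
}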
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
@@ -257,7 +245,7 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType, Type elemPtrPtrType) {
Type structType) {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
@@ -513,10 +501,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
auto llElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
auto llStruct =
getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
auto llStruct = getStructFromElements(loc, elems, rewriter, structTy);
return llStruct;
}
@@ -529,29 +514,7 @@ struct SplatOpConversion
matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = op->getOperand(0);
LLVM::ConstantOp arithConstantOp;
if (src.getDefiningOp() &&
(arithConstantOp =
llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
Value constant;
auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
assert(values.size() == 1);
Attribute val;
if (type::isInt(src.getType())) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(src.getType())) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "Constant op type not supported";
return failure();
}
src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
}
auto src = adaptor.src();
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
@@ -618,12 +581,15 @@ struct StoreOpConversion
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr(); // should be LLVM ops
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
Type valueElemTy = getTypeConverter()->convertType(
value.getType().cast<RankedTensorType>().getElementType());
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
@@ -662,6 +628,7 @@ struct StoreOpConversion
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
@@ -718,17 +685,8 @@ struct StoreOpConversion
const int numVecs = ptrNumElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
size_t in_off{};
auto ptrProducer = llPtr.getDefiningOp();
auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);
if (in_gep) {
auto indices = in_gep.getIndices();
auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
in_off =
cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
ptr = cst ? in_gep.getBase() : in_gep;
}
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
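As a rough illustration of the packing comment above: with fp16 elements and a vectorization factor of 4, the four 16-bit values fit into a single min(16 * 4, 64) = 64-bit store word. The names width, nWords, and wordNElems follow the surrounding StoreOpConversion code, but the concrete numbers in this standalone sketch are made up:

#include <algorithm>
#include <cstdio>

int main() {
  const int nbits = 16;                         // fp16 element width in bits
  const int vec = 4;                            // hypothetical vectorization factor
  const int width = std::min(nbits * vec, 64);  // store word width, capped at 64 bits
  const int nWords = (nbits * vec) / width;     // words emitted per vectorized store
  const int wordNElems = width / nbits;         // elements packed into each word
  std::printf("width=%d nWords=%d wordNElems=%d\n", width, nWords, wordNElems);
  return 0;                                     // prints width=64 nWords=1 wordNElems=4
}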
@@ -747,7 +705,7 @@ struct StoreOpConversion
const bool hasL2EvictPolicy = false;
PtxIOInstr asmStoreInstr("st");
asmStoreInstr.predicate(llMask, "b");
asmStoreInstr.predicate(maskElems[vecIdx], "b");
asmStoreInstr.global().v(width).b(nWords);
llvm::SmallVector<std::string> asmArgs;
@@ -755,7 +713,8 @@ struct StoreOpConversion
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
auto *asmAddr =
asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmArgList = asmStoreInstr.newList();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a width-len composition
@@ -800,9 +759,8 @@ struct StoreOpConversion
LLVM::AsmDialect::AD_ATT), // asm_dialect
ArrayAttr::get(ctx, {}) // operand_attrs
);
rewriter.replaceOp(op, inlineAsm.getRes());
}
rewriter.eraseOp(op);
return success();
}
@@ -1135,6 +1093,10 @@ struct LoadOpConversion
// finally call inline ASM
// ---
SmallVector<Value> args = {pred, ptr};
for (Value v : others) {
args.push_back(v);
}
// TODO: if (has_l2_evict_policy)
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
@@ -1177,6 +1139,95 @@ struct LoadOpConversion
}
};
struct GetProgramIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
auto llvmIndexTy = getTypeConverter()->getIndexType();
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
return success();
}
};
struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto resultLayout =
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
auto offsets =
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] =
rewriter.create<LLVM::GEPOp>(loc, elemTy, ptrs[i], offsets[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
template <typename SourceOp, typename DestOp>
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
auto resultLayout = resultTy.getEncoding()
.template dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto lhss =
this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -1221,14 +1272,20 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
PatternBenefit benefit = 1) {
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<GEPOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
}
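With the BinaryOpConversion template above, wiring up additional elementwise arithmetic ops becomes a one-line registration. A hypothetical example (not part of this commit), assuming the same typeConverter and benefit variables used in populateTritonToLLVMPatterns:

// Hypothetical extension: lower tensor-typed arith::SubIOp elementwise to LLVM::SubOp.
patterns.add<BinaryOpConversion<arith::SubIOp, LLVM::SubOp>>(typeConverter,
                                                             benefit);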