[BACKEND] Add LLVM-translation for store and splat ops (#47)

Author: Yan Chunwei
Date: 2022-08-15 15:46:37 +08:00
Committed by: GitHub
Parent: 993ba7035a
Commit: 95bbac41e7
8 changed files with 815 additions and 57 deletions


@@ -1,5 +1,6 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
PtxAsmFormat.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM


@@ -0,0 +1,81 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace triton {
std::string strJoin(llvm::ArrayRef<std::string> strs,
llvm::StringRef delimiter) {
std::string osStr;
llvm::raw_string_ostream os(osStr);
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++)
os << strs[i] << delimiter;
if (!strs.empty())
os << strs.back();
os.flush();
return osStr;
}
std::string PtxInstr::dump() const {
std::string osStr;
llvm::raw_string_ostream os(osStr);
if (pred)
os << "@" << pred->dump() << " ";
std::string instrRepr = strJoin(instrParts, ".");
llvm::SmallVector<std::string, 4> argReprs;
for (auto *arg : argsInOrder) {
argReprs.push_back(arg->dump());
}
std::string argsRepr = strJoin(argReprs, ", ");
os << instrRepr << " " << argsRepr << ";";
os.flush();
return osStr;
}
llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
SmallVector<Operand *> args;
for (auto *arg : argsInOrder) {
if (arg->isList())
args.insert(args.end(), arg->list.begin(), arg->list.end());
else
args.push_back(arg);
}
return args;
}
PtxInstr::Operand *
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int)> formatter) {
argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
auto *opr = argArchive.back().get();
opr->repr = formatter;
opr->idx = oprCounter++;
return opr;
}
std::string PtxInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return llvm::formatv("%{0}", idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
oprs.push_back(opr->dump());
return "{ " + strJoin(oprs, ", ") + " }";
}
PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
return llvm::formatv("[ %{0} + {1} ]", idx, off);
};
return opr;
}
} // namespace triton
} // namespace mlir
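A minimal usage sketch of the builder above (illustration only, not part of this commit). It assumes v() takes the vector arity and b() the word width (the header is not shown here), and that llMask, llPtr, word0, and word1 are placeholder in-scope LLVM values; compare with how StoreOpConversion composes its store below.

PtxIOInstr st("st");
st.predicate(llMask, "b");  // guard predicate, printed as "@%0 "
st.global().v(2).b(32);     // instruction parts -> "st.global.v2.b32"
auto *addr = st.newAddrOperand(llPtr, "l", /*off=*/16); // "[ %1 + 16 ]"
auto *vals = st.newList();
vals->listAppend(st.newOperand(word0, "r"));
vals->listAppend(st.newOperand(word1, "r"));
st.addOperand(addr);
st.addOperand(vals);
// st.dump() -> "@%0 st.global.v2.b32 [ %1 + 16 ], { %2, %3 };"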


@@ -1,5 +1,4 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -10,10 +9,17 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include <string>
using namespace mlir;
using namespace mlir::triton;
@@ -31,8 +37,15 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace {
namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO(Superjomn) Move to wherever general utilities are located.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// The following code is borrowed from the MLIR project, including the
// following functions and classes:
// - filterFuncAttributes
@@ -81,7 +94,6 @@ protected:
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
assert(llvmType);
if (!llvmType)
return nullptr;
@@ -124,6 +136,8 @@ protected:
}
linkage = attr.getLinkage();
}
auto oldArgs = funcOp.getArguments();
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
@@ -134,6 +148,25 @@ protected:
&result)))
return nullptr;
// Convert argument
llvm::DenseMap<Value, Value> argMap;
for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
Value oldArg = oldArgs[i];
Value newArg = newFuncOp.getArgument(i);
argMap.try_emplace(oldArg, newArg);
}
newFuncOp.getBody().walk([&](Operation *op) {
// Convert the function argument types, e.g., from !tt.ptr<fp16> to
// ptr<fp16>
for (int i = 0; i < op->getNumOperands(); i++) {
auto arg = op->getOperand(i);
auto it = argMap.find(arg);
if (it != argMap.end())
op->setOperand(i, it->second);
}
});
return newFuncOp;
}
};
@@ -143,8 +176,9 @@ protected:
/// information.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
: FuncOpConversionBase(converter), NumWarps(numWarps) {}
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), NumWarps(numWarps) {}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
@@ -154,11 +188,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
return failure();
auto ctx = funcOp->getContext();
auto i32 = IntegerType::get(ctx, 32);
// Set an attribute for maxntidx; it can be used later in LLVM codegen to
// emit the `nvvm.annotation` metadata.
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
rewriter.getIntegerAttr(i32, 32 * NumWarps));
newFuncOp->setAttr(
NVVMMetadataField::MaxNTid,
rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps));
rewriter.eraseOp(funcOp);
return success();
@@ -190,22 +224,47 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
}
};
// Extract numWarps information from a TritonGPU module; return 0 on failure.
// This is a naive implementation: it assumes that all blocked layouts in a
// module share the same numWarps setting, so it just finds one blocked layout
// encoding and returns its warpsPerCTA field.
int extractNumWarps(mlir::ModuleOp module) {
int numWarps{};
if (module->hasAttr(AttrNumWarpsName))
numWarps = module->getAttr(AttrNumWarpsName)
.dyn_cast<IntegerAttr>()
.getValue()
.getZExtValue();
else
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps;
}
static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
ArrayRef<int64_t> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
int64_t rank = shape.size();
int64_t acc_mul = 1;
for (int64_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
int64_t linear_index = 0;
for (int64_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}
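// Illustrative check of getLinearIndex (assumed values): multidim_index =
// {1, 2, 3} with shape = {4, 5, 6} gives 1*30 + 2*6 + 3*1 = 45.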
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
ArrayRef<int64_t> shape) {
return product(shape) / (product(layout.getThreadsPerWarp()) *
product(layout.getWarpsPerCTA()));
}
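// Illustrative check of getElemsPerThread (assumed values): shape = {128, 128}
// with threadsPerWarp = {8, 4} and warpsPerCTA = {4, 2} gives
// 16384 / (32 * 8) = 64 elements per thread.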
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType, Type elemPtrPtrType) {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
loc, structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
template <typename T>
@@ -247,24 +306,6 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
return linear_index;
}
static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout,
ArrayRef<int64_t> shape) {
unsigned elems = 1;
size_t rank = shape.size();
assert(rank == layout.getThreadsPerWarp().size());
for (size_t d = 0; d < rank; ++d) {
elems *=
shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]);
}
return elems;
}
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(resultType, value));
}
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp> {
@@ -443,6 +484,358 @@ public:
}
};
// Convert SplatOp or arith::ConstantOp with a SplatElementsAttr to an
// LLVM::StructType value.
//
// @elemType: the element type of the operand.
// @resType: the return type of the Splat-like op.
// @constVal: an LLVM::ConstantOp or another scalar value.
Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
auto srcType = typeConverter->convertType(elemType);
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
size_t totalElems =
std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
std::multiplies<>{});
size_t numThreads =
product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
// TODO(Superjomn) add numElemsPerThread to the layout encodings.
size_t numElemsPerThread = totalElems / numThreads;
llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
auto llElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
auto llStruct =
getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
return llStruct;
}
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = op->getOperand(0);
LLVM::ConstantOp arithConstantOp;
if (src.getDefiningOp() &&
(arithConstantOp =
llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
Value constant;
auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
assert(values && values.size() == 1);
Attribute val;
if (type::isInt(src.getType())) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(src.getType())) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "Constant op type not supported";
return failure();
}
src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
}
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
return success();
}
};
// This pattern converts arith::ConstantOp (with a SplatElementsAttr); the
// logic is the same as for triton::SplatOp, so the underlying implementation
// is reused.
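// For example, it matches IR like (hypothetical):
//   %cst = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
// and lowers it the same way as a splat of the scalar constant 1.0.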
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto value = op.getValue();
if (!value.dyn_cast<SplatElementsAttr>())
return failure();
auto loc = op->getLoc();
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
auto elemType = values.getElementType();
Attribute val;
if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto llStruct = convertSplatLikeOp(elemType, op.getType(), constOp,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, llStruct);
return success();
}
};
struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
AxisAnalysisPass(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr(); // should be LLVM ops
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
Type valueElemTy = getTypeConverter()->convertType(
value.getType().cast<RankedTensorType>().getElementType());
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
auto getLLVMElems = [&](Value value, Value llValue,
const TritonGPUBlockedEncodingAttr &layout)
-> SmallVector<Value, 4> {
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here we assume that all inputs have a blocked layout.
unsigned valueElems = getElemsPerThread(layout, shape);
auto valueVals =
getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
};
auto getLayout =
[&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here we assume that all inputs have a blocked layout.
auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
};
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
auto getAlign =
[this](Value val,
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
assert(axisInfo.hasValue());
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
};
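// Illustrative check of getAlign (assumed values): divisibility = 16 and
// contiguity = 8 along the most-contiguous dimension give alignment =
// min(16, 8) = 8.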
// Compute the vector width for the access.
auto getVec = [this, &getAlign](
Value val,
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
// The layout order is sorted most-contiguous dimension first, so the
// first element has the largest contiguity.
auto order = layout.getOrder();
unsigned align = getAlign(val, layout);
assert(!order.empty());
// Is this right?
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
// TODO(Superjomn) Consider the is_mma_first_row in the legacy code
bool isMMAFirstRow = false;
if (isMMAFirstRow)
vec = std::min<size_t>(2, align);
return vec;
};
// Determine the vectorization size
size_t vec = getVec(ptr, ptrLayout);
const size_t dtsize = value.getType()
.cast<RankedTensorType>()
.getElementType()
.getIntOrFloatBitWidth() /
8;
const size_t valueElemNbits = dtsize * 8;
const int numVecs = ptrNumElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
size_t in_off{};
auto ptrProducer = llPtr.getDefiningOp();
auto in_gep = llvm::dyn_cast_or_null<LLVM::GEPOp>(ptrProducer);
if (in_gep) {
auto indices = in_gep.getIndices();
auto cst = dyn_cast_or_null<LLVM::ConstantOp>(indices.front().getDefiningOp());
in_off =
cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
ptr = cst ? in_gep.getBase() : in_gep;
}
// Pack sub-words (< 32/64 bits) into words: each store has width
// min(nbits * vec, 32/64), and there are (nbits * vec) / width of them.
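// Illustrative check (assumed values): fp16 elements with vec = 8 give
// totalWidth = 128, width = 32, nWords = 4, wordNElems = 2 (two fp16
// values packed into each b32 word).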
const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits;
assert(wordNElems * nWords * numVecs == valueElems.size());
// TODO(Superjomn) Add cache policy to store.
const bool hasL2EvictPolicy = false;
PtxIOInstr asmStoreInstr("st");
asmStoreInstr.predicate(llMask, "b");
asmStoreInstr.global().v(nWords).b(width);
llvm::SmallVector<std::string> asmArgs;
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
auto *asmArgList = asmStoreInstr.newList();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a width-bit word composed of wordNElems value elements
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition
for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
Value elem =
valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
llWord = rewriter.create<LLVM::InsertElementOp>(
loc, wordTy, llWord, elem,
rewriter.create<LLVM::ConstantOp>(
loc, type::u32Ty(ctx),
IntegerAttr::get(type::u32Ty(ctx), elemIdx)));
}
llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
}
asmStoreInstr.addOperand(asmAddr);
asmStoreInstr.addOperand(asmArgList);
llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
for (int i = 0; i < nWords; i++)
argTys.push_back(valArgTy);
auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});
auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
asmStoreInstr.dump(), // asm_string
asmStoreInstr.getConstrains(), // constraints
// TODO(Superjomn) determine the side effect.
true, // has_side_effects
false, // is_align_stack
LLVM::AsmDialectAttr::get(ctx,
LLVM::AsmDialect::AD_ATT), // asm_dialect
ArrayAttr::get(ctx, {}) // operand_attrs
);
rewriter.replaceOp(op, inlineAsm.getRes());
}
return success();
}
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
private:
AxisInfoAnalysis &AxisAnalysisPass;
};
// Extract numWarps information from a TritonGPU module; return 0 on failure.
// This is a naive implementation: it assumes that all blocked layouts in a
// module share the same numWarps setting, so it just finds one blocked layout
// encoding and returns its warpsPerCTA field.
int extractNumWarps(mlir::ModuleOp module) {
int numWarps{};
if (module->hasAttr(AttrNumWarpsName))
numWarps = module->getAttr(AttrNumWarpsName)
.dyn_cast<IntegerAttr>()
.getValue()
.getZExtValue();
else
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return numWarps;
}
struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -647,8 +1040,7 @@ struct LoadOpConversion
for (size_t i = 0; i < numElems; i += vecWidth) {
Value ptr = ptrVals[i];
// TODO: Handle the optimization if ptr is from GEP and the idx is
// constant
// This should be a canonicalization pattern in LLVM Dialect
// constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0;
Value pred = maskVals[i];
@@ -826,13 +1218,18 @@ public:
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps) {
patterns.add<BroadcastOpConversion>(typeConverter);
patterns.add<FuncOpConversion>(typeConverter, numWarps);
patterns.add<LoadOpConversion>(typeConverter);
patterns.add<MakeRangeOpConversion>(typeConverter);
patterns.add<ReturnOpConversion>(typeConverter);
patterns.add<ViewOpConversion>(typeConverter);
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
PatternBenefit benefit = 1) {
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
}
class ConvertTritonGPUToLLVM
@@ -851,20 +1248,33 @@ public:
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
// TODO: (goostavz) Temporarily disable this, since the lowering of
// arithmetic ops in tensor format is not complete yet.
// Add arith's patterns to help convert scalar expressions to LLVM.
// mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
// patterns);
int numWarps = extractNumWarps(mod);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
auto axisAnalysis = runAxisAnalysis(mod);
// We set a higher benefit here to ensure that Triton's patterns run before
// the arith patterns, since some encodings are not supported by the
// community patterns.
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
*axisAnalysis, 10 /*benefit*/);
// Add arith's patterns to help convert scalar expressions to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
protected:
std::unique_ptr<AxisInfoAnalysis> runAxisAnalysis(ModuleOp module) {
auto axisAnalysisPass =
std::make_unique<AxisInfoAnalysis>(module->getContext());
axisAnalysisPass->run(module);
return axisAnalysisPass;
}
};
} // namespace


@@ -102,6 +102,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
return ret;