[BACKEND] Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp, and bugfixes for SplatOp, StoreOp, FuncOp (#60)
Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp, and GEPOp; fix bugs in the SplatOp, StoreOp, and FuncOp lowerings. Co-authored-by: gzhu <gzhu@nvidia.com>
@@ -197,6 +197,7 @@ target_link_libraries(triton
   MLIRSupport
   MLIRTargetLLVMIRExport
   MLIRExecutionEngine
+  MLIRNVVMToLLVMIRTranslation
 )

 target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
@@ -54,5 +54,6 @@ target_link_libraries(triton-translate PRIVATE
   MLIRExecutionEngine
   MLIRTransformUtils
   MLIRLLVMToLLVMIRTranslation
+  MLIRNVVMToLLVMIRTranslation
 )
 mlir_check_all_link_libraries(triton-translate)
@@ -61,7 +61,7 @@ std::string PtxInstr::Operand::dump() const {
   if (repr)
     return repr(idx);
   if (!isList())
-    return llvm::formatv("%{0}", idx);
+    return llvm::formatv("${0}", idx);
   llvm::SmallVector<std::string> oprs;
   for (auto *opr : list)
     oprs.push_back(opr->dump());
@@ -72,7 +72,7 @@ PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
                                               StringRef constraint, int off) {
   auto *opr = newOperand(addr, constraint);
   opr->repr = [off](int idx) -> std::string {
-    return llvm::formatv("[ %{0} + {1} ]", idx, off);
+    return llvm::formatv("[ ${0} + {1} ]", idx, off);
   };

   return opr;
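These two hunks are the PTX placeholder bugfix: LLVM inline assembly substitutes operands written as $0, $1, ..., while %0 is register syntax that would reach the PTX stream unsubstituted. A minimal standalone sketch of the corrected rendering, in plain C++ without the llvm::formatv dependency (renderOperand and renderAddrOperand are illustrative names, not from the patch):

#include <iostream>
#include <string>

// LLVM inline asm replaces "$N" with the N-th operand; "%N" would be
// emitted verbatim and yield invalid PTX, which is the bug fixed above.
std::string renderOperand(int idx) { return "$" + std::to_string(idx); }

// Mirrors newAddrOperand's repr lambda: "[ $N + off ]".
std::string renderAddrOperand(int idx, int off) {
  return "[ " + renderOperand(idx) + " + " + std::to_string(off) + " ]";
}

int main() { std::cout << renderAddrOperand(1, 0) << "\n"; } // "[ $1 + 0 ]"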
@@ -46,20 +46,10 @@ template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
   return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
 }

-// The following code are borrowed from mlir project including the following
-// functions or classes:
-// - filterFuncAttributes
-// - ConvertOpToLLVMPattern
-// - FuncOpConversion
-//
-// The code are hidden in the CPP files in MLIR repo, and we can't call them
-// directly. I found such code snippets are refactored and added to LLVMCommon
-// in the latest MLIR code, but the v14.0.0 version currentlly used in Triton
-// doesn't contain the code.
+// FuncOpConversion/FuncOpConversionBase is borrowed from
+// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
+// since it is not exposed on header files in mlir v14
+// TODO(Superjomn) Remove the code when mlir v15.0 is included.
-//
-// The original code:
-// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
-// All the rights are reserved by LLVM community.

 /// Only retain those attributes that are not constructed by
@@ -79,6 +69,12 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
   }
 }

+/// Helper function for wrapping all attributes into a single DictionaryAttr
+static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
+  return DictionaryAttr::get(
+      b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
+}
+
 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
 protected:
   using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
@@ -90,25 +86,34 @@ protected:
                          ConversionPatternRewriter &rewriter) const {
     // Convert the original function arguments. They are converted using the
     // LLVMTypeConverter provided to this legalization pattern.
-    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
+    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
     auto llvmType = getTypeConverter()->convertFunctionSignature(
         funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
     if (!llvmType)
       return nullptr;

-    // Propagate argument attributes to all converted arguments obtained after
-    // converting a given original argument.
+    // Propagate argument/result attributes to all converted arguments/result
+    // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
+    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
                          attributes);
+    if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
+      assert(!resAttrDicts.empty() && "expected array to be non-empty");
+      auto newResAttrDicts =
+          (funcOp.getNumResults() == 1)
+              ? resAttrDicts
+              : rewriter.getArrayAttr(
+                    {wrapAsStructAttrs(rewriter, resAttrDicts)});
+      attributes.push_back(rewriter.getNamedAttr(
+          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+    }
     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
       SmallVector<Attribute, 4> newArgAttrs(
           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
       for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
         auto mapping = result.getInputMapping(i);
-        assert(mapping.hasValue() &&
-               "unexpected deletion of function argument");
+        assert(mapping && "unexpected deletion of function argument");
         for (size_t j = 0; j < mapping->size; ++j)
           newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
       }
@@ -136,37 +141,15 @@ protected:
       }
       linkage = attr.getLinkage();
     }

-    auto oldArgs = funcOp.getArguments();
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
         /*dsoLocal*/ false, attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());

     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                            &result)))
       return nullptr;

-    // Convert argument
-    llvm::DenseMap<Value, Value> argMap;
-    for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
-      Value oldArg = oldArgs[i];
-      Value newArg = newFuncOp.getArgument(i);
-      argMap.try_emplace(oldArg, newArg);
-    }
-
-    newFuncOp.getBody().walk([&](Operation *op) {
-      // Convert the function argument types, e.g, from !tt.ptr<fp16> to
-      // ptr<fp16>
-      for (int i = 0; i < op->getNumOperands(); i++) {
-        auto arg = op->getOperand(i);
-        auto it = argMap.find(arg);
-        if (it != argMap.end())
-          op->setOperand(i, it->second);
-      }
-    });
-
     return newFuncOp;
   }
 };
@@ -245,8 +228,13 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,

 static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
                                   ArrayRef<int64_t> shape) {
-  return product(shape) / (product(layout.getThreadsPerWarp()) *
-                           product(layout.getWarpsPerCTA()));
+  size_t rank = shape.size();
+  SmallVector<unsigned> elemsPerThreadPerDim(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
+    elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
+  }
+  return product<unsigned>(elemsPerThreadPerDim);
 }

 static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
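The rewrite above changes elems-per-thread from a single global division to a per-dimension ceiling division, so a tensor dimension smaller than the covering thread count still assigns each thread at least one element. A standalone sketch of the arithmetic (function and parameter names are illustrative, not from the patch):

#include <cstdint>
#include <vector>

// Per-dimension ceil-division, as in the new getElemsPerThread: dimension i
// is covered by t = threadsPerWarp[i] * warpsPerCTA[i] threads, and each
// thread owns ceil(shape[i] / t) elements along that dimension.
unsigned elemsPerThread(const std::vector<int64_t> &shape,
                        const std::vector<unsigned> &threadsPerWarp,
                        const std::vector<unsigned> &warpsPerCTA) {
  unsigned elems = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    unsigned t = threadsPerWarp[i] * warpsPerCTA[i];
    elems *= (shape[i] + t - 1) / t; // ceil division
  }
  return elems;
}

For shape {64} with threadsPerWarp {32} and warpsPerCTA {4}, the old formula 64 / (32 * 4) truncates to 0; the ceil-div version yields 1.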
@@ -257,7 +245,7 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,

 Value getStructFromElements(Location loc, ValueRange resultVals,
                             ConversionPatternRewriter &rewriter,
-                            Type structType, Type elemPtrPtrType) {
+                            Type structType) {
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
   for (auto v : llvm::enumerate(resultVals)) {
     llvmStruct = rewriter.create<LLVM::InsertValueOp>(
@@ -513,10 +501,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
   auto structTy =
       LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);

-  auto llElemPtrPtrTy =
-      LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
-  auto llStruct =
-      getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
+  auto llStruct = getStructFromElements(loc, elems, rewriter, structTy);
   return llStruct;
 }
@@ -529,29 +514,7 @@ struct SplatOpConversion
   matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
-    auto src = op->getOperand(0);
-
-    LLVM::ConstantOp arithConstantOp;
-    if (src.getDefiningOp() &&
-        (arithConstantOp =
-             llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
-      Value constant;
-      auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
-
-      assert(values.size() == 1);
-      Attribute val;
-      if (type::isInt(src.getType())) {
-        val = values.getValues<IntegerAttr>()[0];
-      } else if (type::isFloat(src.getType())) {
-        val = values.getValues<FloatAttr>()[0];
-      } else {
-        llvm::errs() << "Constant op type not supported";
-        return failure();
-      }
-
-      src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
-    }
-
+    auto src = adaptor.src();
     auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
                                        getTypeConverter(), rewriter, loc);
     rewriter.replaceOp(op, {llStruct});
@@ -618,12 +581,15 @@ struct StoreOpConversion
     Value mask = op.mask();
     Value value = op.value();

-    Value llPtr = adaptor.ptr(); // should be LLVM ops
+    Value llPtr = adaptor.ptr();
     Value llMask = adaptor.mask();
     Value llValue = adaptor.value();

-    Type valueElemTy = getTypeConverter()->convertType(
-        value.getType().cast<RankedTensorType>().getElementType());
+    auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());

     MLIRContext *ctx = rewriter.getContext();
     auto loc = op->getLoc();
@@ -662,6 +628,7 @@ struct StoreOpConversion
     auto [maskLayout, maskNumElems] = getLayout(mask);
     auto [valueLayout, valueNumElems] = getLayout(value);

+    auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
     auto valueElems = getLLVMElems(value, llValue, valueLayout);
     auto maskElems = getLLVMElems(mask, llMask, maskLayout);
     assert(valueElems.size() == maskElems.size());
@@ -718,17 +685,8 @@ struct StoreOpConversion
     const int numVecs = ptrNumElems / vec;
     for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {

-      size_t in_off{};
-      auto ptrProducer = llPtr.getDefiningOp();
-      auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);
-
-      if (in_gep) {
-        auto indices = in_gep.getIndices();
-        auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
-        in_off =
-            cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
-        ptr = cst ? in_gep.getBase() : in_gep;
-      }
+      // TODO: optimization when ptr is GEP with constant offset
+      size_t in_off = 0;

       // pack sub-words (< 32/64bits) into words
       // each load has width min(nbits*vec, 32/64)
@@ -747,7 +705,7 @@ struct StoreOpConversion
       const bool hasL2EvictPolicy = false;

       PtxIOInstr asmStoreInstr("st");
-      asmStoreInstr.predicate(llMask, "b");
+      asmStoreInstr.predicate(maskElems[vecIdx], "b");
       asmStoreInstr.global().v(width).b(nWords);

       llvm::SmallVector<std::string> asmArgs;
@@ -755,7 +713,8 @@ struct StoreOpConversion
       Type valArgTy = IntegerType::get(ctx, width);
       auto wordTy = VectorType::get(wordNElems, valueElemTy);

-      auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
+      auto *asmAddr =
+          asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off);
       auto *asmArgList = asmStoreInstr.newList();
       for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
         // llWord is a width-len composition
@@ -800,9 +759,8 @@ struct StoreOpConversion
              LLVM::AsmDialect::AD_ATT), // asm_dialect
          ArrayAttr::get(ctx, {})        // operand_attrs
      );

-      rewriter.replaceOp(op, inlineAsm.getRes());
     }
+    rewriter.eraseOp(op);
     return success();
   }
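Net effect of the StoreOp changes: each vector iteration now predicates on its own mask element (maskElems[vecIdx]) and addresses through its own pointer element (ptrElems[vecIdx]) instead of reusing the whole-tensor llMask/llPtr structs, and the op is erased once after the loop rather than replaced inside it. A sketch of the inline-asm template this path assembles for one iteration, in plain C++ (buildStoreAsm is an illustrative helper, not from the patch; it mirrors the commit's .v(width).b(nWords) composition as seen in the CHECK lines below):

#include <iostream>
#include <string>

// Assemble the predicated-store template fed to llvm.inline_asm.
// Operand $0 is the mask bit, $1 the address, $2.. the data words.
std::string buildStoreAsm(int width, int nWords, int off) {
  std::string s = "@$0 st.global.v" + std::to_string(width) + ".b" +
                  std::to_string(nWords) + " [ $1 + " + std::to_string(off) +
                  " ], { ";
  for (int w = 0; w < nWords; ++w)
    s += (w ? ", $" : "$") + std::to_string(2 + w);
  return s + " };";
}

int main() {
  // Prints "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };", matching the test.
  std::cout << buildStoreAsm(32, 1, 0) << "\n";
}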
@@ -1135,6 +1093,10 @@ struct LoadOpConversion
     // finally call inline ASM
     // ---
     SmallVector<Value> args = {pred, ptr};
+    for (Value v : others) {
+      args.push_back(v);
+    }
+    // TODO: if (has_l2_evict_policy)
     auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                     LLVM::AsmDialect::AD_ATT);
     auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
@@ -1177,6 +1139,95 @@ struct LoadOpConversion
   }
 };

+struct GetProgramIdOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
+        loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
+    auto llvmIndexTy = getTypeConverter()->getIndexType();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{llvmIndexTy}, ValueRange{blockId});
+    return success();
+  }
+};
+
+struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    auto resultLayout =
+        resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto resultShape = resultTy.getShape();
+    unsigned elems = getElemsPerThread(resultLayout, resultShape);
+    Type elemTy =
+        this->getTypeConverter()->convertType(resultTy.getElementType());
+    SmallVector<Type> types(elems, elemTy);
+    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
+    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
+    auto offsets =
+        getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
+    SmallVector<Value> resultVals(elems);
+    for (unsigned i = 0; i < elems; ++i) {
+      resultVals[i] =
+          rewriter.create<LLVM::GEPOp>(loc, elemTy, ptrs[i], offsets[i]);
+    }
+    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, view);
+    return success();
+  }
+};
+
+template <typename SourceOp, typename DestOp>
+class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
+                              PatternBenefit benefit = 1)
+      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
+    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
+    if (!resultTy)
+      return failure();
+
+    Location loc = op->getLoc();
+    auto resultLayout = resultTy.getEncoding()
+                            .template dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto resultShape = resultTy.getShape();
+    unsigned elems = getElemsPerThread(resultLayout, resultShape);
+    Type elemTy =
+        this->getTypeConverter()->convertType(resultTy.getElementType());
+    SmallVector<Type> types(elems, elemTy);
+    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+    auto lhss =
+        this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
+    auto rhss =
+        this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
+    SmallVector<Value> resultVals(elems);
+    for (unsigned i = 0; i < elems; ++i) {
+      resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
+    }
+    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, view);
+    return success();
+  }
+};
+
 class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
 public:
   using TypeConverter::convertType;
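GEPOpConversion and BinaryOpConversion share one scheme: a blocked tensor lives in registers as an LLVM struct with one field per element owned by the thread, so a lowering unpacks the struct, applies the scalar op element by element, and repacks. A dependency-free sketch of that unpack/apply/repack flow (lowerElementwise and the vector stand-in for the LLVM struct are illustrative, not from the patch):

#include <functional>
#include <vector>

// Stand-in for the per-thread LLVM struct: one scalar slot per element.
using ElementStruct = std::vector<float>;

// Unpack both operand structs, apply the scalar op slot by slot, and
// repack the results - the shape of BinaryOpConversion's rewrite loop.
ElementStruct lowerElementwise(const ElementStruct &lhs,
                               const ElementStruct &rhs,
                               const std::function<float(float, float)> &op) {
  ElementStruct result(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i)
    result[i] = op(lhs[i], rhs[i]); // e.g. llvm.fadd for arith.addf
  return result;
}

Calling lowerElementwise(a, b, std::plus<float>{}) mirrors how arith.addf becomes one llvm.fadd per getElemsPerThread slot.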
@@ -1221,14 +1272,20 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns, int numWarps,
                                   AxisInfoAnalysis &analysis,
                                   PatternBenefit benefit = 1) {
-  patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
-  patterns.add<SplatOpConversion>(typeConverter, benefit);
   patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
-  patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
-  patterns.add<ReturnOpConversion>(typeConverter, benefit);
+  patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
+                                                               benefit);
+  patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
+                                                                benefit);
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
+  patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
+  patterns.add<GEPOpConversion>(typeConverter, benefit);
+  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<LoadOpConversion>(typeConverter, benefit);
   patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
+  patterns.add<ReturnOpConversion>(typeConverter, benefit);
+  patterns.add<SplatOpConversion>(typeConverter, benefit);
+  patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
   patterns.add<ViewOpConversion>(typeConverter, benefit);
 }
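The registration list is now alphabetized, and each lowering is a single add call, so each BinaryOpConversion instantiation maps one tensor-level arith op to one LLVM scalar op. A hypothetical extension (not part of this commit) reusing the same template for subtraction would follow the identical pattern:

  // Hypothetical follow-up additions, same template, different op pair:
  patterns.add<BinaryOpConversion<arith::SubIOp, LLVM::SubOp>>(typeConverter,
                                                               benefit);
  patterns.add<BinaryOpConversion<arith::SubFOp, LLVM::FSubOp>>(typeConverter,
                                                                benefit);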
@@ -8,8 +8,10 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Transforms/Passes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
 #include "triton/driver/llvm.h"
 #include "llvm/IR/Constants.h"
@@ -82,7 +84,8 @@ std::unique_ptr<llvm::Module>
 translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
   auto context = module->getContext();
   DialectRegistry registry;
-  registerLLVMDialectTranslation(registry);
+  mlir::registerLLVMDialectTranslation(registry);
+  mlir::registerNVVMDialectTranslation(registry);
   context->appendDialectRegistry(registry);

   llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
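The NVVM registration matters because the new GetProgramIdOp lowering emits nvvm.read.ptx.sreg.ctaid.x; without a registered translation interface for the NVVM dialect, exporting the module to LLVM IR fails. A minimal sketch of the register-then-translate flow against the MLIR v14 API (toLLVMIR is an illustrative wrapper; module construction is elided):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/Module.h"

// Register the dialect-to-LLVM-IR hooks before translating; every dialect
// present in the module (here LLVM and NVVM) needs its translation interface.
std::unique_ptr<llvm::Module> toLLVMIR(mlir::ModuleOp module,
                                       llvm::LLVMContext &llvmContext) {
  mlir::DialectRegistry registry;
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  module->getContext()->appendDialectRegistry(registry);
  return mlir::translateModuleToLLVMIR(module.getOperation(), llvmContext);
}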
@@ -123,6 +126,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
   applyPassManagerCLOptions(pm);

   pm.addPass(createConvertTritonGPUToLLVMPass());
+  // Conanicalize to eliminate the remaining UnrealizedConversionCastOp
+  pm.addPass(mlir::createCanonicalizerPass());

   if (failed(pm.run(module))) {
     llvm::errs() << "Pass execution failed";
@@ -1504,7 +1504,7 @@ void init_triton_ir(py::module &&m) {
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUVerifier());
           })
-      .def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
+      .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
         self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
       });
 }
python/test/vecadd_no_scf.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import triton
+import triton.language as tl
+
+NUM_WARPS = 4
+
+# triton kernel
+
+
+@triton.jit
+def kernel(x_ptr, stride_xn,
+           y_ptr, stride_yn,
+           z_ptr, stride_zn,
+           BLOCK_SIZE_N: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    x_ptrs = x_ptr + offset
+    y_ptrs = y_ptr + offset
+    x = tl.load(x_ptrs)
+    y = tl.load(y_ptrs)
+    z = x + y
+    z_ptrs = z_ptr + offset
+    tl.store(z_ptrs, z)
+
+
+ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
+
+print(ret)
+
+# TODO: base class for python end2end tests,
+# runtime execution, correctness comparison etc.
@@ -19,6 +19,8 @@ func @test_splat(%ptr: !tt.ptr<f32>) {
   return
 }

+// -----
+
 func @test_store_splat(%ptr: !tt.ptr<f32>) {
   %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
   %a = arith.constant 1.0 : f32
@@ -27,9 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
   %vs = tt.splat %a : (f32) -> tensor<128xf32>
   %mask = tt.splat %true : (i1) -> tensor<128xi1>

-  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
-  // CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>
-
+  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };",
+  // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
   tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>

   return
@@ -112,28 +112,81 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   }
 }

-// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
-// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-// module attributes {"triton_gpu.num-warps" = 4 : i32} {
-//   func @debut_kernel(%lb : index, %A : !tt.ptr<f32>, %B : !tt.ptr<f32>, %C : !tt.ptr<f32>) {
-//     %cst = arith.constant dense<true> : tensor<256xi1, #blocked0>
-//     %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
-//     %cst_1 = arith.constant dense<true> : tensor<1024x256xi1, #blocked1>
-//     %cst_2 = arith.constant dense<true> : tensor<256x2048xi1, #blocked2>
-//     %a_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
-//     %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
-//     %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1>
-//     %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1>
-//     %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
-//     %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2>
-//     %b_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<1024x256x!tt.ptr<f32>, #blocked1>
-//     %c_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x2048x!tt.ptr<f32>, #blocked2>
-//     tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1>
-//     tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2>
-//     return
-//   }
-// }
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_addf
+  func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
+    // CHECK: llvm.fadd
+    // CHECK: llvm.fadd
+    %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_addi
+  func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
+    // CHECK: llvm.add
+    // CHECK: llvm.add
+    %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
+    return
+  }
+}
+
+// -----
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_program_id
+  func @basic_program_id() {
+    // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_gep
+  func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
+    // CHECK: llvm.getelementptr
+    // CHECK: llvm.getelementptr
+    %0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK: basic_splat
+  func @basic_splat(%ptr: !tt.ptr<f32>) {
+    // CHECK: llvm.mlir.undef
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.insertvalue
+    %0 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>,#blocked0>
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: basic_store
+  func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
+    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+    // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
+    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
+    return
+  }
+}