[BACKEND] Codegen bringup, index calculation of blocked_layout & support of LoadOp, BroadcastOp, ViewOp & MakeRangeOp (#38)

Co-authored-by: gzhu <gzhu@nvidia.com>
goostavz
2022-08-15 10:58:59 +08:00
committed by GitHub
parent e5ec8e16ea
commit 993ba7035a
6 changed files with 785 additions and 22 deletions

View File

@@ -33,11 +33,13 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::StandardOpsDialect",
"mlir::gpu::GPUDialect",
"mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::StandardOpsDialect"];
}
#endif

View File

@@ -1,7 +1,9 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

View File

@@ -13,6 +13,9 @@ add_mlir_conversion_library(TritonGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUToNVVMTransforms
MLIRGPUTransforms
TritonIR
TritonGPUIR
TritonGPUTransforms

View File

@@ -1,9 +1,14 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -12,6 +17,9 @@
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr;
using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr;
using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr;
namespace mlir {
namespace LLVM {
@@ -23,6 +31,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace {
class TritonGPUToLLVMTypeConverter;
// The following code is borrowed from the MLIR project, including the
// following functions or classes:
// - filterFuncAttributes
@@ -134,17 +144,7 @@ protected:
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
: FuncOpConversionBase(converter), NumWarps(numWarps) {
mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion(
[&](triton::PointerType type) {
return convertTritonPointerType(type);
});
}
Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
type.getAddressSpace());
}
: FuncOpConversionBase(converter), NumWarps(numWarps) {}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
@@ -172,7 +172,7 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adapter,
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
unsigned numArguments = op.getNumOperands();
@@ -208,12 +208,631 @@ int extractNumWarps(mlir::ModuleOp module) {
return numWarps;
}
} // namespace
template <typename T>
static SmallVector<T> getMultiDimIndex(T linear_index, ArrayRef<T> shape) {
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
size_t rank = shape.size();
T acc_mul = 1;
for (size_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
T linear_remain = linear_index;
SmallVector<T> multidim_index(rank);
for (size_t i = 0; i < rank; ++i) {
multidim_index[i] = linear_remain / acc_mul;
linear_remain = linear_remain % acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return multidim_index;
}
template <typename T>
static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
assert(multidim_index.size() == shape.size());
// sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
size_t rank = shape.size();
T acc_mul = 1;
for (size_t i = 1; i < rank; ++i) {
acc_mul *= shape[i];
}
T linear_index = 0;
for (size_t i = 0; i < rank; ++i) {
linear_index += multidim_index[i] * acc_mul;
if (i != (rank - 1)) {
acc_mul = acc_mul / shape[i + 1];
}
}
return linear_index;
}
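// A worked example of the two helpers above (illustrative only): with
// shape = {2, 3, 4} the running stride takes the values {12, 4, 1}, so
//   getMultiDimIndex(17, {2, 3, 4})      -> {1, 1, 1}   (17 = 1*12 + 1*4 + 1)
//   getLinearIndex({1, 1, 1}, {2, 3, 4}) -> 17
// i.e. getLinearIndex is the inverse of getMultiDimIndex for in-range values.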
static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout,
ArrayRef<int64_t> shape) {
unsigned elems = 1;
size_t rank = shape.size();
assert(rank == layout.getThreadsPerWarp().size());
for (size_t d = 0; d < rank; ++d) {
elems *=
shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]);
}
return elems;
}
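// For instance, with the layout used by several lit tests below
// (sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4]) and
// shape = [256], each thread owns 256 / (32 * 4) = 2 elements; sizePerThread
// does not change the count, only how those elements are tiled.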
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(resultType, value));
}
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
SmallVector<Value, 4>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value, 4> results(elems);
for (unsigned i = 0; i < elems; ++i) {
Type type =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
results[i] = rewriter.create<LLVM::ExtractValueOp>(
loc, type, llvmStruct, rewriter.getI64ArrayAttr(i));
}
return results;
}
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) const {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
loc, structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
unsigned rank = shape.size();
assert(rank == order.size());
SmallVector<unsigned> reordered(rank);
for (unsigned i = 0; i < rank; ++i) {
reordered[i] = shape[order[i]];
}
return delinearize(rewriter, loc, linear, reordered);
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape) const {
unsigned rank = shape.size();
assert(rank > 0);
SmallVector<Value> multiDim(rank);
if (rank == 1) {
multiDim[0] = linear;
} else {
Value remained = linear;
for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) {
Value dimSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
en.value());
multiDim[rank - 1 - en.index()] =
rewriter.create<LLVM::URemOp>(loc, remained, dimSize);
remained = rewriter.create<LLVM::UDivOp>(loc, remained, dimSize);
}
multiDim[0] = remained;
}
return multiDim;
}
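// As an illustration, delinearize(linear = 13, shape = {4, 8}) emits the
// equivalent of {13 / 8, 13 % 8} = {1, 5}: the innermost (fastest varying)
// dimension is peeled off first with URem/UDiv on the index type.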
// Emit indices calculation within each ConversionPattern
// TODO: [goostavz] Double-check that the redundant index calculations are
// eliminated by subsequent MLIR/LLVM optimizations
SmallVector<SmallVector<Value>> emitIndicesForBlockedLayout(
Location loc, ConversionPatternRewriter &b,
const TritonGPUBlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = b.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{b.create<::mlir::gpu::ThreadIdOp>(
loc, b.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
Value warpSize = createIndexAttrConstant(b, loc, llvmIndexTy, 32);
Value laneId = b.create<LLVM::URemOp>(loc, threadId, warpSize);
Value warpId = b.create<LLVM::UDivOp>(loc, threadId, warpSize);
auto sizePerThread = blocked_layout.getSizePerThread();
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
SmallVector<Value, 4> threadIds(rank);
// step 1, delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(b, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(b, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] *
// threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK =
createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]);
Value sizePerThreadK =
createIndexAttrConstant(b, loc, llvmIndexTy, sizePerThread[k]);
multiDimBase[k] = b.create<LLVM::MulOp>(
loc, sizePerThreadK,
b.create<LLVM::AddOp>(
loc, multiDimThreadId[k],
b.create<LLVM::MulOp>(loc, multiDimWarpId[k], threadsPerWarpK)));
}
// step 2, get offset of each element
unsigned elemsPerThread = 1;
SmallVector<SmallVector<unsigned>> offset(rank);
SmallVector<unsigned> multiDimElemsPerThread(rank);
for (unsigned k = 0; k < rank; ++k) {
multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k];
elemsPerThread *= multiDimElemsPerThread[k];
for (unsigned blockOffset = 0;
blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
++blockOffset) {
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k];
++warpOffset) {
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
++threadOffset) {
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
++elemOffset) {
offset[k].push_back(blockOffset * sizePerThread[k] *
threadsPerWarp[k] * warpsPerCTA[k] +
warpOffset * sizePerThread[k] *
threadsPerWarp[k] +
threadOffset * sizePerThread[k] + elemOffset);
}
}
}
}
}
// step 3, add the offsets to the base and reorder the sequence of indices
// so that elements belonging to the same sizePerThread tile end up adjacent
// in order
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread);
unsigned accumSizePerThread =
std::accumulate(sizePerThread.begin(), sizePerThread.end(), 1,
std::multiplies<unsigned>());
SmallVector<unsigned> threadsPerDim(rank);
for (unsigned k = 0; k < rank; ++k) {
threadsPerDim[k] = shape[k] / sizePerThread[k];
}
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / accumSizePerThread;
unsigned linearElemsInNanoTileId = n % accumSizePerThread;
SmallVector<unsigned> multiDimNanoTileId =
getMultiDimIndex<unsigned>(linearNanoTileId, threadsPerDim);
SmallVector<unsigned> multiElemsInNanoTileId =
getMultiDimIndex<unsigned>(linearElemsInNanoTileId, sizePerThread);
multiDimIdx[n].resize(rank);
for (unsigned k = 0; k < rank; ++k) {
unsigned reorderedMultiDimId =
multiDimNanoTileId[k] *
(sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
multiElemsInNanoTileId[k];
multiDimIdx[n][k] = b.create<LLVM::AddOp>(
loc, multiDimBase[k],
createIndexAttrConstant(b, loc, llvmIndexTy,
offset[k][reorderedMultiDimId]));
}
}
return multiDimIdx;
}
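// Worked example (assuming the 1-D layout of the make_range lit test below:
// sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4],
// shape = [256]): elemsPerThread = 2, multiDimBase = [threadId], and the
// reordered offsets are {0, 128}, so thread t receives the indices
// {t, t + 128}.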
};
struct BroadcastOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::BroadcastOp>::ConvertTritonGPUOpToLLVMPattern;
// Following the order of indices in the legacy code, a broadcast of:
// [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)]
// =>
// [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)]
//
// logically maps to a broadcast within a thread's scope:
// [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1),
// 1,spt(k+1)..spt(n-1)]
// =>
// [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)]
//
// regardless of the order of the layout
//
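// For example, with #blocked2 from the lit test below (sizePerThread = [1, 1],
// threadsPerWarp = [32, 1], warpsPerCTA = [4, 1]), broadcasting
// tensor<256x1xf32> to tensor<256x4xf32> leaves dim 0 untouched and duplicates
// each per-thread source value 4 times along dim 1, which is why the test
// expects every extracted value to be inserted four times.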
LogicalResult
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value src = adaptor.src();
Value result = op.result();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto resultTy = result.getType().cast<RankedTensorType>();
auto srcLayout =
srcTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultLayout =
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
assert(srcLayout && (srcLayout == resultLayout) &&
"Unexpected layout of BroadcastOp");
auto srcShape = srcTy.getShape();
auto resultShape = resultTy.getShape();
unsigned rank = srcTy.getRank();
// TODO: [goostavz] double confirm the op semantics with Phil
assert(rank == resultTy.getRank());
SmallVector<int64_t, 4> srcLogicalShape(2 * rank);
SmallVector<int64_t, 4> resultLogicalShape(2 * rank);
SmallVector<unsigned, 2> broadcastDims;
SmallVector<int64_t, 2> broadcastSizes;
int64_t duplicates = 1;
for (unsigned d = 0; d < rank; ++d) {
int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] *
resultLayout.getThreadsPerWarp()[d] *
resultLayout.getWarpsPerCTA()[d]);
if (srcShape[d] != resultShape[d]) {
assert(srcShape[d] == 1);
broadcastDims.push_back(d);
broadcastSizes.push_back(resultShape[d]);
srcLogicalShape[d] = 1;
srcLogicalShape[d + rank] = 1;
duplicates *= resultShape[d];
} else {
srcLogicalShape[d] = numCtas;
srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
}
resultLogicalShape[d] = numCtas;
resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d];
}
unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
auto elemTy = resultTy.getElementType();
auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < srcElems; ++i) {
auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);
auto resultMultiDim = srcMultiDim;
for (int64_t j = 0; j < duplicates; ++j) {
auto bcastMultiDim = getMultiDimIndex<int64_t>(j, broadcastSizes);
for (auto bcastDim : llvm::enumerate(broadcastDims)) {
resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()];
}
auto resultLinearIndex =
getLinearIndex<int64_t>(resultMultiDim, resultLogicalShape);
resultVals[resultLinearIndex] = srcVals[i];
}
}
auto llvmStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
struct ViewOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ViewOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().cast<RankedTensorType>();
auto resultLayout =
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto blocked_layout =
rankedTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs =
emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) {
assert(multiDim.value().size() == 1);
retVals[multiDim.index()] =
rewriter.create<LLVM::AddOp>(loc, multiDim.value()[0], start);
}
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
};
struct LoadOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value ptr = adaptor.ptr();
Value mask = adaptor.mask();
Value other = adaptor.other();
auto resultTy = op.result().getType().cast<RankedTensorType>();
auto blockedLayout =
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto shape = resultTy.getShape();
// TODO: Handle AxisInfo
// vecWidth = std::min(nts, aln)
// TODO: special processing for mma_first_row in the legacy code
assert(blockedLayout && "LoadOp only accepts blocked_layout");
unsigned vecWidth =
blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]];
auto elemTy = resultTy.getElementType();
unsigned numElems = getElemsPerThread(blockedLayout, shape);
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
auto otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
unsigned nbits = elemTy.isa<FloatType>()
? elemTy.cast<FloatType>().getWidth()
: elemTy.cast<IntegerType>().getWidth();
// unsigned dtsize = nbits / 8;
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits * vecWidth;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
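// e.g. a vectorized f32 load with vecWidth = 4 gives tot_width = 128,
// width = 32, n_words = 4 (ld.global.v4.b32 in the lit tests below), while
// the f16 variant gives tot_width = 64, width = 32, n_words = 2
// (ld.global.v2.b32).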
// TODO: currently disable until supported in `store`
bool has_l2_evict_policy = false;
// TODO: (goostavz) handle the case where `other` is constant but not splat,
// which should be rare
bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (elemTy.isa<IntegerType>() &&
matchPattern(op.other(), m_Constant(&constAttr)) &&
constAttr.isSplat()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
SmallVector<Value> loadedVals;
for (size_t i = 0; i < numElems; i += vecWidth) {
Value ptr = ptrVals[i];
// TODO: Handle the optimization if ptr is from GEP and the idx is
// constant
// This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0;
Value pred = maskVals[i];
// ---
// create inline asm string
// ---
// TODO: (Superjomn) refactor with AsmInstr abstraction
std::ostringstream asmOss;
asmOss << "@$" << n_words; // predicate
asmOss << " ld";
if (op.isVolatile()) {
asmOss << ".volatile";
}
asmOss << ".global";
if (op.cache() == triton::CacheModifier::CA)
asmOss << ".ca";
if (op.cache() == triton::CacheModifier::CG)
asmOss << ".cg";
if (op.evict() == triton::EvictionPolicy::EVICT_FIRST)
asmOss << ".L1::evict_first";
if (op.evict() == triton::EvictionPolicy::EVICT_LAST)
asmOss << ".L1::evict_last";
if (has_l2_evict_policy)
asmOss << ".L2::cache_hint";
if (n_words > 1)
asmOss << ".v" << n_words; // vector width
asmOss << ".b" << width; // word size
asmOss << " {";
for (int i = 0; i < n_words; i++) { // return values
if (i > 0)
asmOss << ",";
asmOss << "$" << i;
}
asmOss << "}";
asmOss << ", [ $" << n_words + 1; // load
asmOss << " + " << in_off << "]"; // constant offset
if (has_l2_evict_policy)
asmOss << ", $" << n_words + 2;
asmOss << ";";
SmallVector<Value> others;
for (size_t ii = 0; ii < n_words; ii++) {
size_t size = width / nbits;
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t s = 0; s < size; s++) {
Value falseVal = otherVals[i + ii * size + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
sVal);
}
v = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(getContext(), width), v);
asmOss << "\n ";
asmOss << "@!$" << n_words << " mov.u" << width;
asmOss << " $" << ii << ", ";
std::ios_base::fmtflags flags(asmOss.flags());
if (otherIsSplatConstInt)
asmOss << "0x" << std::hex << splatVal;
else {
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
others.push_back(v);
}
asmOss.flags(flags);
asmOss << ";";
}
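// For the vectorized f32 case above (n_words = 4, width = 32, no L2 eviction
// policy, `other` passed through mov), the assembled string looks roughly like
//   @$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];
//    @!$4 mov.u32 $0, $6;
//    ...
//    @!$4 mov.u32 $3, $9;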
// ---
// create inline ASM signature
// ---
SmallVector<Type> retTys(n_words, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0];
// ---
// create inline ASM constraints
// ---
std::string asmCstrt;
for (int ii = 0; ii < n_words; ii++) {
if (ii > 0)
asmCstrt += ",";
asmCstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
}
asmCstrt += ",b,l";
for (size_t ii = 0; ii < others.size(); ii++) {
asmCstrt += ",";
asmCstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy) {
asmCstrt += ",l";
}
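// ... and the matching constraint string for the case above would be roughly
// "=r,=r,=r,=r,b,l,r,r,r,r": one "=r" per returned word, "b" for the
// predicate, "l" for the 64-bit pointer, and one "r" per mov fallback value.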
// ---
// finally call inline ASM
// ---
SmallVector<Value> args = {pred, ptr};
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, retTy, /*operands=*/args, /*asm_string=*/asmOss.str(),
/*constraints=*/asmCstrt, /*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
Value ret = inlineAsmOp.getResult(0);
// ---
// extract and store return values
// ---
SmallVector<Value> rets;
for (unsigned int ii = 0; ii < n_words; ii++) {
Value curr = nullptr;
if (retTy.isa<LLVM::LLVMStructType>()) {
curr = rewriter.create<LLVM::ExtractValueOp>(
loc, IntegerType::get(getContext(), width), ret,
rewriter.getI64ArrayAttr(ii));
} else {
curr = ret;
}
curr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr);
rets.push_back(curr);
}
int tmp = (width / nbits);
for (size_t ii = 0; ii < vecWidth; ii++) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Value loaded = rewriter.create<LLVM::ExtractElementOp>(
loc, elemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded);
}
}
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
}
Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
type.getAddressSpace());
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
Attribute layout = type.getEncoding();
if (auto blocked_layout = layout.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
unsigned numElementsPerThread =
getElemsPerThread(blocked_layout, type.getShape());
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
} else if (auto mma_layout = layout.dyn_cast<TritonGPUMmaEncodingAttr>()) {
// TODO: Not implemented
return llvm::None;
} else if (auto shared_layout =
layout.dyn_cast<TritonGPUSharedEncodingAttr>()) {
// TODO: Not implemented
return llvm::None;
}
return llvm::None;
}
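// For reference, tensor<256xf32, #blocked0> with threadsPerWarp = [32] and
// warpsPerCTA = [4] converts to a literal !llvm.struct of 2 x f32, one member
// per element owned by the thread; mma and shared encodings are left
// unimplemented for now.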
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps) {
patterns.add<::FuncOpConversion>(typeConverter, numWarps);
patterns.add<::ReturnOpConversion>(typeConverter);
patterns.add<BroadcastOpConversion>(typeConverter);
patterns.add<FuncOpConversion>(typeConverter, numWarps);
patterns.add<LoadOpConversion>(typeConverter);
patterns.add<MakeRangeOpConversion>(typeConverter);
patterns.add<ReturnOpConversion>(typeConverter);
patterns.add<ViewOpConversion>(typeConverter);
}
class ConvertTritonGPUToLLVM
@@ -225,29 +844,41 @@ public:
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
LLVMTypeConverter typeConverter(context);
mlir::LowerToLLVMOptions option(context);
// TODO: needs confirmation
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
// TODO: (goostavz) Temporarily disabled, since the lowering of arithmetic
// ops on tensors is not complete yet.
// Add arith's patterns to help convert scalar expressions to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
// mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
// patterns);
int numWarps = extractNumWarps(mod);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
namespace mlir {
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
// addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
namespace triton {

View File

@@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by the warp size of 32
// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
@@ -13,3 +13,128 @@ func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
}
} // end module
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_load
func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load
func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v4.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v4.b32
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load_f16
func @vectorized_load_f16(%a_ptr_init : tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v2.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v2.b32
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
return
}
}
// -----
// TODO: pending support for isSplat constants
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
}
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
// is from a GEP with const idx
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_view_broadcast
func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
%0 = tt.view %arg : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
// CHECK: llvm.insertvalue %[[T1]]
%1 = tt.broadcast %0 : (tensor<256x1xf32,#blocked2>) -> tensor<256x4xf32, #blocked2>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
func @basic_make_range() {
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
return
}
}
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// module attributes {"triton_gpu.num-warps" = 4 : i32} {
// func @debut_kernel(%lb : index, %A : !tt.ptr<f32>, %B : !tt.ptr<f32>, %C : !tt.ptr<f32>) {
// %cst = arith.constant dense<true> : tensor<256xi1, #blocked0>
// %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
// %cst_1 = arith.constant dense<true> : tensor<1024x256xi1, #blocked1>
// %cst_2 = arith.constant dense<true> : tensor<256x2048xi1, #blocked2>
// %a_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
// %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1>
// %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1>
// %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2>
// %b_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<1024x256x!tt.ptr<f32>, #blocked1>
// %c_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x2048x!tt.ptr<f32>, #blocked2>
// tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1>
// tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2>
// return
// }
// }

View File

@@ -4,7 +4,7 @@
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_empty_kernel
// CHECK: !nvvm.annotations
// CHECK: !{void (i64, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
module attributes {"triton_gpu.num-warps" = 4 : i32} {