@@ -1,5 +1,6 @@
 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
 #include "../PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -307,16 +308,7 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
   return linear_index;
 }
 
-template <typename SourceOp>
-class ConvertTritonGPUOpToLLVMPattern
-    : public ConvertOpToLLVMPattern<SourceOp> {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-
+struct ConvertTritonGPUOpToLLVMPatternBase {
   SmallVector<Value>
   getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
                         ConversionPatternRewriter &rewriter) const {
@@ -329,18 +321,18 @@ public:
     }
     return results;
   }
-
-  Value getStructFromElements(Location loc, ValueRange resultVals,
-                              ConversionPatternRewriter &rewriter,
-                              Type structType) const {
-    Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-    for (auto v : llvm::enumerate(resultVals)) {
-      llvmStruct = rewriter.create<LLVM::InsertValueOp>(
-          loc, structType, llvmStruct, v.value(),
-          rewriter.getI64ArrayAttr(v.index()));
-    }
-    return llvmStruct;
-  }
+};
+
+template <typename SourceOp>
+class ConvertTritonGPUOpToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp>,
+      public ConvertTritonGPUOpToLLVMPatternBase {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
 
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
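
// How the two struct helpers pair up (an illustrative sketch, values assumed,
// not part of the patch): a tensor's per-thread elements travel through this
// lowering packed as one LLVM struct, so for a struct of four elements
//   SmallVector<Value> elems =
//       getElementsFromStruct(loc, llStruct, /*elems=*/4, rewriter);
//   Value repacked = getStructFromElements(loc, elems, rewriter, structTy);
// the extractvalue/insertvalue pair reproduces the original struct contents.
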
@@ -432,24 +424,18 @@ public:
       for (unsigned blockOffset = 0;
            blockOffset <
            shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
-           ++blockOffset) {
-        for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k];
-             ++warpOffset) {
+           ++blockOffset)
+        for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
           for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
-               ++threadOffset) {
+               ++threadOffset)
             for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
-                 ++elemOffset) {
+                 ++elemOffset)
               offset[k].push_back(blockOffset * sizePerThread[k] *
                                       threadsPerWarp[k] * warpsPerCTA[k] +
                                   warpOffset * sizePerThread[k] *
                                       threadsPerWarp[k] +
                                   threadOffset * sizePerThread[k] + elemOffset);
-            }
-          }
-        }
-      }
     }
 
     // step 3, add offset to base, and reorder the sequence of indices,
     // to guarantee that elems in a same sizePerThread are adjacent in
     // order
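
// Worked example for the offset formula above (illustrative numbers): with
// sizePerThread[k] = 2, threadsPerWarp[k] = 32, warpsPerCTA[k] = 4 and
// shape[k] = 512, the outer loop visits 512 / (2 * 32 * 4) = 2 block offsets,
// and (blockOffset, warpOffset, threadOffset, elemOffset) = (1, 2, 3, 1) maps
// to
//   1 * (2 * 32 * 4) + 2 * (2 * 32) + 3 * 2 + 1 = 256 + 128 + 6 + 1 = 391.
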
@@ -535,9 +521,9 @@ struct SplatOpConversion
   }
 };
 
-// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the
-// logic is the same as triton::SplatOp, so the underlying implementation is
-// reused.
+// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
+// the logic is the same as triton::SplatOp, so the underlying implementation
+// is reused.
 struct ArithConstantSplatOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -576,20 +562,104 @@ struct ArithConstantSplatOpConversion
   }
 };
 
+// Contains some helper functions for both Load and Store conversions.
+struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
+  LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
+      : AxisAnalysisPass(axisAnalysisPass) {}
+
+  // Get corresponding LLVM element values of \param value.
+  SmallVector<Value> getLLVMElems(Value value, Value llValue,
+                                  const BlockedEncodingAttr &layout,
+                                  TypeConverter *typeConverter,
+                                  ConversionPatternRewriter &rewriter,
+                                  Location loc) const {
+    if (!value)
+      return {};
+
+    auto ty = value.getType().cast<RankedTensorType>();
+    auto shape = ty.getShape();
+    // Here, we assume that all inputs should have a blockedLayout
+    unsigned valueElems = getElemsPerThread(layout, shape);
+
+    auto llvmElemTy = typeConverter->convertType(ty.getElementType());
+    auto llvmElemPtrPtrTy =
+        LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
+
+    auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
+    return valueVals;
+  }
+
+  // Get the blocked layout.
+  std::tuple<BlockedEncodingAttr, unsigned> getLayout(Value val) const {
+    auto ty = val.getType().cast<RankedTensorType>();
+    // Here, we assume that all inputs should have a blockedLayout
+    auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto shape = ty.getShape();
+    unsigned valueElems = getElemsPerThread(layout, shape);
+    return std::make_tuple(layout, valueElems);
+  }
+
+  unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
+    auto axisInfo = getAxisInfo(val);
+
+    auto order = layout.getOrder();
+
+    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
+    unsigned maxContig = axisInfo->getContiguity(order[0]);
+    unsigned alignment = std::min(maxMultiple, maxContig);
+    return alignment;
+  }
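
// Worked example (illustrative numbers, not from the patch): if AxisInfo
// proves the fastest-varying axis (order[0]) of a pointer tensor divisible by
// 16 but only contiguous in runs of 8, then
//   alignment = std::min(/*maxMultiple=*/16u, /*maxContig=*/8u) == 8
// i.e. at most 8 consecutive elements can safely be accessed as one vector.
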
+  unsigned getVectorizeSize(Value ptr,
+                            const BlockedEncodingAttr &layout) const {
+    auto axisInfo = getAxisInfo(ptr);
+    auto contig = axisInfo->getContiguity();
+    // Here order should be ordered by contiguous first, so the first element
+    // should have the largest contiguous.
+    auto order = layout.getOrder();
+    unsigned align = getAlignment(ptr, layout);
+
+    auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
+      auto ty = val.getType().cast<RankedTensorType>();
+      auto shape = ty.getShape();
+      return shape;
+    };
+
+    // unsigned contigPerThread = layout.getSizePerThread()[order[0]];
+    unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
+
+    unsigned vec = std::min(align, contigPerThread);
+
+    return vec;
+  }
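
// Continuing the example above (assumed numbers): with alignment == 8 and
// contigPerThread == 4, getVectorizeSize returns std::min(8u, 4u) == 4, so
// loads and stores on this tensor are emitted as 4-element vector accesses.
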
+  llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
+    if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
+      return it->getValue();
+    }
+
+    return llvm::Optional<AxisInfo>{};
+  }
+
+protected:
+  AxisInfoAnalysis &AxisAnalysisPass;
+};
+
 struct StoreOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
+    : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
+      public LoadStoreConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
 
   StoreOpConversion(LLVMTypeConverter &converter,
                     AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
-        AxisAnalysisPass(axisAnalysisPass) {}
+        LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
   matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value ptr = op.ptr();
     Value mask = op.mask();
     Value value = op.value();
@@ -598,125 +668,52 @@ struct StoreOpConversion
     Value llMask = adaptor.mask();
     Value llValue = adaptor.value();
 
-    MLIRContext *ctx = rewriter.getContext();
-    auto loc = op->getLoc();
-
-    auto getLLVMElems =
-        [&](Value value, Value llValue,
-            const BlockedEncodingAttr &layout) -> SmallVector<Value> {
-      auto ty = value.getType().cast<RankedTensorType>();
-      auto shape = ty.getShape();
-      // Here, we assume that all inputs should have a blockedLayout
-      unsigned valueElems = getElemsPerThread(layout, shape);
-
-      auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
-      auto llvmElemPtrPtrTy =
-          LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
-
-      auto valueVals =
-          getElementsFromStruct(loc, llValue, valueElems, rewriter);
-      return valueVals;
-    };
-
-    auto getLayout =
-        [&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
-      auto ty = val.getType().cast<RankedTensorType>();
-      auto shape = ty.getShape();
-      // Here, we assume that all inputs should have a blockedLayout
-      auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
-
-      unsigned valueElems = getElemsPerThread(layout, shape);
-
-      return std::make_tuple(layout, valueElems);
-    };
-
-    auto [ptrLayout, ptrNumElems] = getLayout(ptr);
-    auto [valueLayout, valueNumElems] = getLayout(value);
-
-    auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
-    auto valueElems = getLLVMElems(value, llValue, valueLayout);
+    auto loc = op->getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());
+
+    auto [layout, numElems] = getLayout(ptr);
+
+    auto ptrElems =
+        getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
+    auto valueElems =
+        getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
+    assert(ptrElems.size() == valueElems.size());
 
     SmallVector<Value> maskElems;
     if (llMask) {
-      auto [maskLayout, maskNumElems] = getLayout(mask);
-      maskElems = getLLVMElems(mask, llMask, maskLayout);
+      maskElems =
+          getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
       assert(valueElems.size() == maskElems.size());
     }
 
-    auto getAlign = [this](Value val,
-                           const BlockedEncodingAttr &layout) -> unsigned {
-      auto axisInfo = getAxisInfo(val);
-      assert(axisInfo.hasValue());
-
-      auto order = layout.getOrder();
-
-      unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
-      unsigned maxContig = axisInfo->getContiguity(order[0]);
-      unsigned alignment = std::min(maxMultiple, maxContig);
-      return alignment;
-    };
-
-    // get align
-    auto getVec = [this,
-                   &getAlign](Value val,
-                              const BlockedEncodingAttr &layout) -> unsigned {
-      auto axisInfo = getAxisInfo(val);
-      auto contig = axisInfo->getContiguity();
-      // Here order should be ordered by contiguous first, so the first element
-      // should have the largest contiguous.
-      auto order = layout.getOrder();
-      unsigned align = getAlign(val, layout);
-
-      assert(!order.empty());
-      // Is this right?
-      unsigned contigPerThread = layout.getSizePerThread()[order[0]];
-      unsigned vec = std::min(align, contigPerThread);
-
-      // TODO(Superjomn) Consider the is_mma_first_row in the legacy code
-      bool isMMAFirstRow = false;
-
-      if (isMMAFirstRow)
-        vec = std::min<size_t>(2, align);
-
-      return vec;
-    };
-
     // Determine the vectorization size
-    size_t vec = getVec(ptr, ptrLayout);
-    const size_t dtsize = value.getType()
-                              .cast<RankedTensorType>()
-                              .getElementType()
-                              .getIntOrFloatBitWidth() /
-                          8;
+    size_t vec = getVectorizeSize(ptr, layout);
+    const size_t dtsize =
+        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNbits = dtsize * 8;
 
-    const int numVecs = ptrNumElems / vec;
-    for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
+    const int numVecs = numElems / vec;
+    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
       // TODO: optimization when ptr is GEP with constant offset
       size_t in_off = 0;
 
       // pack sub-words (< 32/64bits) into words
       // each load has width min(nbits*vec, 32/64)
       // and there are (nbits * vec)/width of them
       const int maxWordWidth = std::max<int>(32, valueElemNbits);
       const int totalWidth = valueElemNbits * vec;
       const int width = std::min(totalWidth, maxWordWidth);
       const int nWords = std::max(1, totalWidth / width);
       const int wordNElems = width / valueElemNbits;
       const int vecNElems = totalWidth / valueElemNbits;
-      assert(wordNElems * nWords * numVecs == valueElems.size());
+      assert(wordNElems * nWords * numVecs == numElems);
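
// Worked example of the packing arithmetic above (assumed: fp16 stores, so
// valueElemNbits == 16, with vec == 8):
//   totalWidth = 16 * 8       = 128 bits per vector access
//   width      = min(128, 32) = 32-bit words
//   nWords     = 128 / 32     = 4 words per access
//   wordNElems = 32 / 16      = 2 elements packed into each word
// so each vectorized store writes four b32 words covering 8 fp16 values.
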
 
-      // TODO(Superjomn) Add cache policy to store.
-      // TODO(Superjomn) deal with cache policy.
+      // TODO(Superjomn) Add cache policy fields to StoreOp.
+      // TODO(Superjomn) Deal with cache policy here.
       const bool hasL2EvictPolicy = false;
 
       PTXBuilder ptxBuilder;
@@ -733,8 +730,9 @@ struct StoreOpConversion
         Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
         // Insert each value element to the composition
         for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
-          Value elem =
-              valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
+          const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
+          assert(elemOffset < valueElems.size());
+          Value elem = valueElems[elemOffset];
           if (elem.getType().isInteger(1))
             elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
           elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
@@ -751,13 +749,17 @@ struct StoreOpConversion
         asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
       }
+
       // TODO(Superjomn) Need to check masks before vectorize the load for all
       // the values share one predicate? Here assume all the mask values are
       // the same.
       Value maskVal =
-          llMask ? maskElems[vecIdx]
+          llMask ? maskElems[vecStart]
                  : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
                                              rewriter.getIntegerType(1), 1);
       ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
-      auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
+
+      auto *asmAddr =
+          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
+
       ptxStoreInstr(asmAddr, asmArgList);
       Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
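
// For illustration (assumed shapes; the actual string comes from PTXBuilder):
// with nWords == 2 and width == 32, the builder assembles an inline-asm store
// of roughly the form
//   @$0 st.global.v2.b32 [ $1 + 0 ], { $2, $3 };
// where $0 is the mask predicate, $1 the address, and $2/$3 the packed words.
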
@@ -782,17 +784,6 @@ struct StoreOpConversion
     rewriter.eraseOp(op);
     return success();
   }
-
-  llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
-    if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
-      return it->getValue();
-    }
-
-    return llvm::Optional<AxisInfo>{};
-  }
-
-private:
-  AxisInfoAnalysis &AxisAnalysisPass;
 };
 
 struct BroadcastOpConversion
@@ -909,8 +900,10 @@ struct ViewOpConversion
 
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
+                                                             benefit) {}
 
   LogicalResult
   matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
@@ -918,13 +911,12 @@ struct MakeRangeOpConversion
     Location loc = op->getLoc();
     auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
     auto shape = rankedTy.getShape();
-    auto blocked_layout =
-        rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
 
     auto elemTy = rankedTy.getElementType();
     assert(elemTy.isInteger(32));
     Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
-    auto idxs =
-        emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape);
+    auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
     unsigned elems = idxs.size();
     SmallVector<Value> retVals(elems);
     for (auto multiDim : llvm::enumerate(idxs)) {
@@ -941,92 +933,116 @@ struct MakeRangeOpConversion
 };
 
 struct LoadOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp> {
+    : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
+      public LoadStoreConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
       triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
 
+  LoadOpConversion(LLVMTypeConverter &converter,
+                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Value ptr = adaptor.ptr();
-    Value mask = adaptor.mask();
-    Value other = adaptor.other();
-    auto resultTy = op.result().getType().cast<RankedTensorType>();
-    auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    auto shape = resultTy.getShape();
-
-    // TODO: Handle AxisInfo
-    //   vecWidth = std::min(nts, aln)
-    // TODO: special processing for mma_first_row in legacy codes
-    assert(blockedLayout && "LoadOp only accepts blocked_layout");
-    unsigned vecWidth =
-        blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]];
-
-    auto elemTy = resultTy.getElementType();
-    unsigned numElems = getElemsPerThread(blockedLayout, shape);
-    auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
-    SmallVector<Value> maskVals;
-    if (mask) {
-      maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
+    Value ptr = op.ptr();
+    Value mask = op.mask();
+    Value other = op.other();
+
+    Value llPtr = adaptor.ptr();
+    Value llMask = adaptor.mask();
+    Value llOther = adaptor.other();
+
+    auto loc = op->getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());
+
+    auto [layout, numElems] = getLayout(ptr);
+
+    auto ptrElems =
+        getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
+    assert(ptrElems.size() == numElems);
+
+    SmallVector<Value> maskElems;
+    if (llMask) {
+      maskElems =
+          getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
+      assert(ptrElems.size() == maskElems.size());
     }
-    SmallVector<Value> otherVals;
-    if (other) {
-      otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
-    }
-    unsigned nbits = elemTy.isa<FloatType>()
-                         ? elemTy.cast<FloatType>().getWidth()
-                         : elemTy.cast<IntegerType>().getWidth();
-    // unsigned dtsize = nbits / 8;
-    int max_word_width = std::max<int>(32, nbits);
-    int tot_width = nbits * vecWidth;
-    int width = std::min(tot_width, max_word_width);
-    int n_words = std::max(1, tot_width / width);
-    // TODO: currently disable until supported in `store`
-    bool has_l2_evict_policy = false;
+
+    // Determine the vectorization size
+    size_t vec = getVectorizeSize(ptr, layout);
+
+    const size_t dtsize =
+        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
+    const size_t valueElemNbits = dtsize * 8;
+
+    const int numVecs = numElems / vec;
 
     // TODO: (goostavz) handle when other is const but not splat, which
     // should be rarely seen
     bool otherIsSplatConstInt = false;
     DenseElementsAttr constAttr;
     int64_t splatVal = 0;
-    if (elemTy.isa<IntegerType>() &&
+    if (valueElemTy.isa<IntegerType>() &&
        matchPattern(op.other(), m_Constant(&constAttr)) &&
        constAttr.isSplat()) {
      otherIsSplatConstInt = true;
      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
     }
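
// Note on `other`: in Triton, a masked-off lane of tt.load yields the `other`
// operand instead of reading memory. Detecting an integer splat here lets the
// lowering below plug the constant straight into the PTX as an immediate
// (see newConstantOperand(splatVal) further down) rather than materializing a
// register operand for it.
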
+    auto otherElems =
+        getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
 
     SmallVector<Value> loadedVals;
-    for (size_t i = 0; i < numElems; i += vecWidth) {
-      Value ptr = ptrVals[i];
-
-      // TODO: Handle the optimization if ptr is from GEP and the idx is
-      // constant. This should be a canonicalization pattern in LLVM Dialect
-      unsigned in_off = 0;
-
-      PTXBuilder ptxBuilder;
-      auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
+    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
+      // TODO: optimization when ptr is GEP with constant offset
+      size_t in_off = 0;
+
+      const int maxWordWidth = std::max<int>(32, valueElemNbits);
+      const int totalWidth = valueElemNbits * vec;
+      const int width = std::min(totalWidth, maxWordWidth);
+      const int nWords = std::max(1, totalWidth / width);
+      const int wordNElems = width / valueElemNbits;
+      const int vecNElems = totalWidth / valueElemNbits;
+      assert(wordNElems * nWords * numVecs == numElems);
+
+      // TODO(Superjomn) Add cache policy fields to StoreOp.
+      // TODO(Superjomn) Deal with cache policy here.
+      const bool hasL2EvictPolicy = false;
 
       // TODO(Superjomn) Need to check masks before vectorize the load for all
       // the values share one predicate? Here assume all the mask values are
       // the same.
       Value pred =
-          mask ? maskVals[i]
+          mask ? maskElems[vecStart]
               : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
                                           rewriter.getIntegerType(1), 1);
 
       // ---
       // create inline asm string
       // ---
 
-      const std::string readConstrait =
+      const std::string readConstraint =
          (width == 64) ? "l" : ((width == 32) ? "r" : "c");
-      const std::string writeConstrait =
+      const std::string writeConstraint =
          (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
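
// Background (NVPTX inline-asm convention, not defined in this diff): the
// constraint letter picks the register class bound to each asm operand --
// "l" is a 64-bit register, "r" a 32-bit register, and "c" an 8-bit one; a
// leading "=" marks an operand the instruction writes rather than reads.
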
 
+      PTXBuilder ptxBuilder;
+      PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
+
       // prepare asm operands
       auto *dstsOpr = ptxBuilder.newListOperand();
-      for (int i = 0; i < n_words; i++) {
-        auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
+      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
+        auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
         dstsOpr->listAppend(opr);
       }
-      auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
+
+      auto *addrOpr =
+          ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
 
       // Define the instruction opcode
       ld.predicate(pred, "b")
@@ -1037,11 +1053,12 @@ struct LoadOpConversion
           .o("L1::evict_first",
              op.evict() == triton::EvictionPolicy::EVICT_FIRST)
           .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
-          .o("L1::cache_hint", has_l2_evict_policy)
-          .v(n_words)
+          .o("L1::cache_hint", hasL2EvictPolicy)
+          .v(nWords)
           .b(width);
 
+      PTXBuilder::Operand *evictOpr{};
+
       // Here lack a mlir::Value to bind to this operation, so disabled.
       // if (has_l2_evict_policy)
       //   evictOpr = ptxBuilder.newOperand(l2Evict, "l");
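
// For illustration (assumed nWords == 4, width == 32, default eviction): the
// builder chain above assembles an instruction of roughly the form
//   @$0 ld.global.v4.b32 { $1, $2, $3, $4 }, [ $5 + 0 ];
// with the optional L1::evict_* / L1::cache_hint qualifiers spliced in only
// when their guard expression is true.
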
@@ -1053,16 +1070,16 @@ struct LoadOpConversion
 
       SmallVector<Value> others;
       if (other) {
-        for (size_t ii = 0; ii < n_words; ii++) {
+        for (size_t ii = 0; ii < nWords; ii++) {
           PTXInstr &mov = *ptxBuilder.create<>("mov");
           mov.predicateNot(pred, "b").o("u", width);
 
-          size_t size = width / nbits;
+          size_t size = width / valueElemNbits;
 
-          auto vecTy = LLVM::getFixedVectorType(elemTy, size);
+          auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
           Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
           for (size_t s = 0; s < size; s++) {
-            Value falseVal = otherVals[i + ii * size + s];
+            Value falseVal = otherElems[vecStart + ii * size + s];
             Value sVal = createIndexAttrConstant(
                 rewriter, loc, this->getTypeConverter()->getIndexType(), s);
             v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
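
// The guarded mov gives masked-off destination words their fallback value,
// e.g. (sketch, operand numbers assumed):
//   @!$0 mov.u32 $1, $2;   // pred false: dst word = packed `other` bits
// so every destination register is defined regardless of the predicate.
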
@@ -1075,7 +1092,7 @@ struct LoadOpConversion
           if (otherIsSplatConstInt) {
             opr = ptxBuilder.newConstantOperand(splatVal);
           } else {
-            opr = ptxBuilder.newOperand(v, readConstrait);
+            opr = ptxBuilder.newOperand(v, readConstraint);
             others.push_back(v);
           }
@@ -1086,7 +1103,7 @@ struct LoadOpConversion
       // ---
       // create inline ASM signature
       // ---
-      SmallVector<Type> retTys(n_words, IntegerType::get(getContext(), width));
+      SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
       Type retTy = retTys.size() > 1
                        ? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
                        : retTys[0];
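
// Example of the resulting signature (illustrative): with nWords == 4 and
// width == 32 the inline asm returns the literal struct type
//   !llvm.struct<(i32, i32, i32, i32)>
// while a single-word load returns a plain i32.
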
@@ -1097,7 +1114,8 @@ struct LoadOpConversion
       auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
           loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
           /*asm_string=*/ptxBuilder.dump(),
-          /*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true,
+          /*constraints=*/ptxBuilder.getConstrains(),
+          /*has_side_effects=*/true,
           /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
           /*operand_attrs=*/ArrayAttr());
       Value ret = inlineAsmOp.getResult(0);
@@ -1106,8 +1124,8 @@ struct LoadOpConversion
       // extract and store return values
       // ---
       SmallVector<Value> rets;
-      for (unsigned int ii = 0; ii < n_words; ii++) {
-        Value curr = nullptr;
+      for (unsigned int ii = 0; ii < nWords; ii++) {
+        Value curr;
         if (retTy.isa<LLVM::LLVMStructType>()) {
           curr = rewriter.create<LLVM::ExtractValueOp>(
               loc, IntegerType::get(getContext(), width), ret,
@@ -1116,19 +1134,21 @@ struct LoadOpConversion
           curr = ret;
         }
         curr = rewriter.create<LLVM::BitcastOp>(
-            loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr);
+            loc, LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
+            curr);
         rets.push_back(curr);
       }
-      int tmp = (width / nbits);
-      for (size_t ii = 0; ii < vecWidth; ii++) {
+      int tmp = (width / valueElemNbits);
+      for (size_t ii = 0; ii < vec; ii++) {
         Value vecIdx = createIndexAttrConstant(
             rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
         Value loaded = rewriter.create<LLVM::ExtractElementOp>(
-            loc, elemTy, rets[ii / tmp], vecIdx);
+            loc, valueElemTy, rets[ii / tmp], vecIdx);
         loadedVals.push_back(loaded);
       }
     } // end vec
 
-    Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
+    Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
     Value resultStruct =
         getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
     rewriter.replaceOp(op, {resultStruct});
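
// Worked example of the unpacking loop (assumed fp16, vec == 8, width == 32):
// tmp = 32 / 16 = 2 elements per returned word, so rets[ii / 2] is bitcast to
// vector<2xf16> and lane ii % 2 is extracted; ii = 0..7 walks rets[0]..rets[3]
// and recovers all 8 loaded fp16 values in order.
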
@@ -1272,11 +1292,16 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                                                benefit);
   patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
                                                                 benefit);
+  patterns.add<BinaryOpConversion<arith::MulIOp, LLVM::MulOp>>(typeConverter,
+                                                               benefit);
+  patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
+                                                                benefit);
+
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
   patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
   patterns.add<GEPOpConversion>(typeConverter, benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
-  patterns.add<LoadOpConversion>(typeConverter, benefit);
+  patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
   patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
   patterns.add<SplatOpConversion>(typeConverter, benefit);
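
// Hypothetical call site (names assumed; `analysis` is the AxisInfoAnalysis
// instance threaded through this function's updated signature, which is not
// shown in the diff):
//
//   AxisInfoAnalysis analysis(module.getContext());
//   analysis.run(module); // ForwardDataFlowAnalysis entry point (assumption)
//   RewritePatternSet patterns(module.getContext());
//   populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, analysis,
//                                /*benefit=*/1);
//
// A sketch only: the analysis must be run before pattern application so that
// Load/Store conversions can query divisibility and contiguity.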