[Backend] Vectorize Load/Store Ops (#86)

This PR does the following:

- Refactor the Load and Store op codegen: both are now lowered with the same logic and share most of the code
- Support vectorized load/store (see the sketch below)
Author: Yan Chunwei
Date: 2022-09-07 03:28:09 +08:00
Committed by: GitHub
Parent: 35e346bcff
Commit: a9464f4993
10 changed files with 433 additions and 295 deletions
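How the vectorization width is chosen, in brief: getAlignment takes the minimum of the pointer's divisibility and its contiguity along the fastest-varying dimension (order[0]), and getVectorizeSize clamps that by the number of contiguous elements a thread owns. A minimal Python sketch of that decision; the numeric inputs are illustrative, the real pass reads them from AxisInfoAnalysis:

def vectorize_size(divisibility, contiguity, elems_per_thread):
    # alignment: power-of-two multiple known for the address, limited by how
    # many consecutive elements are actually contiguous in memory
    alignment = min(divisibility, contiguity)
    # vector width: never wider than what a single thread owns contiguously
    return min(alignment, elems_per_thread)

# e.g. fp32 pointers divisible by 4 elements (16 bytes), contiguity 4, and
# 4 contiguous elements per thread -> vec = 4 (ld.global.v4 / st.global.v4)
print(vectorize_size(divisibility=4, contiguity=4, elems_per_thread=4))  # 4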

View File

@@ -27,9 +27,9 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;
} else {
llvm::errs() << "op: " << op->getName() << "\n";
}
} else {
llvm::errs() << "op: " << op->getName() << "\n";
}
// XXX(Keren): triton ops don't support aliasing yet.
// else if (auto viewOp = dyn_cast<triton::ViewOp>(op) ||

View File

@@ -214,4 +214,4 @@ ChangeResult AxisInfoAnalysis::visitOperation(
return result;
}
} // namespace mlir
} // namespace mlir

View File

@@ -1,5 +1,6 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -307,16 +308,7 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
return linear_index;
}
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
struct ConvertTritonGPUOpToLLVMPatternBase {
SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const {
@@ -329,18 +321,18 @@ public:
}
return results;
}
};
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) const {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
loc, structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
return llvmStruct;
}
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp>,
public ConvertTritonGPUOpToLLVMPatternBase {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
@@ -432,24 +424,18 @@ public:
for (unsigned blockOffset = 0;
blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
++blockOffset) {
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k];
++warpOffset) {
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
++threadOffset) {
++threadOffset)
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
++elemOffset) {
++elemOffset)
offset[k].push_back(blockOffset * sizePerThread[k] *
threadsPerWarp[k] * warpsPerCTA[k] +
warpOffset * sizePerThread[k] *
threadsPerWarp[k] +
threadOffset * sizePerThread[k] + elemOffset);
}
}
}
}
}
// step 3, add offset to base, and reorder the sequence of indices,
// to guarantee that elems in a same sizePerThread are adjacent in
// order
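For a single dimension k, the loop nest above records blockOffset * sizePerThread * threadsPerWarp * warpsPerCTA + warpOffset * sizePerThread * threadsPerWarp + threadOffset * sizePerThread + elemOffset for every block/warp/thread/element combination. A small Python sketch of that arithmetic, with illustrative layout parameters:

def dim_offsets(shape_k, size_per_thread, threads_per_warp, warps_per_cta):
    # mirrors the per-dimension offset enumeration above
    offsets = []
    reps = shape_k // (size_per_thread * threads_per_warp * warps_per_cta)
    for block in range(reps):
        for warp in range(warps_per_cta):
            for thread in range(threads_per_warp):
                for elem in range(size_per_thread):
                    offsets.append(block * size_per_thread * threads_per_warp * warps_per_cta
                                   + warp * size_per_thread * threads_per_warp
                                   + thread * size_per_thread
                                   + elem)
    return offsets

# e.g. shape 256 with sizePerThread=2, threadsPerWarp=32, warpsPerCTA=4
assert dim_offsets(256, 2, 32, 4) == list(range(256))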
@@ -535,9 +521,9 @@ struct SplatOpConversion
}
};
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the
// logic is the same as triton::SplatOp, so the underlying implementation is
// reused.
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
// the logic is the same as triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
@@ -576,20 +562,104 @@ struct ArithConstantSplatOpConversion
}
};
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: AxisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = getElemsPerThread(layout, shape);
auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
}
// Get the blocked layout.
std::tuple<BlockedEncodingAttr, unsigned> getLayout(Value val) const {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto shape = ty.getShape();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
}
unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
auto contig = axisInfo->getContiguity();
// Here order is sorted with the most contiguous dimension first, so
// order[0] has the largest contiguity.
auto order = layout.getOrder();
unsigned align = getAlignment(ptr, layout);
auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
return shape;
};
// unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
unsigned vec = std::min(align, contigPerThread);
return vec;
}
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
protected:
AxisInfoAnalysis &AxisAnalysisPass;
};
struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
AxisAnalysisPass(axisAnalysisPass) {}
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value value = op.value();
@@ -598,125 +668,52 @@ struct StoreOpConversion
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
auto [layout, numElems] = getLayout(ptr);
auto getLLVMElems =
[&](Value value, Value llValue,
const BlockedEncodingAttr &layout) -> SmallVector<Value> {
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
auto valueElems =
getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
assert(ptrElems.size() == valueElems.size());
unsigned valueElems = getElemsPerThread(layout, shape);
auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals =
getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
};
auto getLayout =
[&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
};
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
SmallVector<Value> maskElems;
if (llMask) {
auto [maskLayout, maskNumElems] = getLayout(mask);
maskElems = getLLVMElems(mask, llMask, maskLayout);
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
assert(valueElems.size() == maskElems.size());
}
auto getAlign = [this](Value val,
const BlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
assert(axisInfo.hasValue());
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
};
// get align
auto getVec = [this,
&getAlign](Value val,
const BlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
auto contig = axisInfo->getContiguity();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = layout.getOrder();
unsigned align = getAlign(val, layout);
assert(!order.empty());
// Is this right?
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
// TODO(Superjomn) Consider the is_mma_first_row in the legacy code
bool isMMAFirstRow = false;
if (isMMAFirstRow)
vec = std::min<size_t>(2, align);
return vec;
};
// Determine the vectorization size
size_t vec = getVec(ptr, ptrLayout);
size_t vec = getVectorizeSize(ptr, layout);
const size_t dtsize = value.getType()
.cast<RankedTensorType>()
.getElementType()
.getIntOrFloatBitWidth() /
8;
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = ptrNumElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
assert(wordNElems * nWords * numVecs == valueElems.size());
// TODO(Superjomn) Add cache policy to store.
// TODO(Superjomn) deal with cache policy.
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
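For reference, the word-packing constants above (maxWordWidth, totalWidth, width, nWords, wordNElems) decide how many machine words one vectorized access uses and how many elements go into each word. A worked Python sketch of that arithmetic with example element widths:

def pack(elem_bits, vec):
    # mirrors the width / nWords / wordNElems computation above
    max_word_width = max(32, elem_bits)       # a word is at least 32 bits wide
    total_width = elem_bits * vec             # bits moved per vectorized access
    width = min(total_width, max_word_width)  # width of a single word
    n_words = max(1, total_width // width)    # words per access (the .v count)
    word_n_elems = width // elem_bits         # elements packed into one word
    return width, n_words, word_n_elems

print(pack(32, 4))  # (32, 4, 1): fp32, vec=4 -> the v4.b32 form checked in the lit tests
print(pack(16, 8))  # (32, 4, 2): fp16, vec=8 -> four 32-bit words, two f16 each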
@@ -733,8 +730,9 @@ struct StoreOpConversion
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition
for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
Value elem =
valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
@@ -751,13 +749,17 @@ struct StoreOpConversion
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
}
// TODO(Superjomn): should the masks be checked before vectorizing, since all
// the values share one predicate? Here we assume all the mask values are
// the same.
Value maskVal =
llMask ? maskElems[vecIdx]
llMask ? maskElems[vecStart]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
ptxStoreInstr(asmAddr, asmArgList);
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -782,17 +784,6 @@ struct StoreOpConversion
rewriter.eraseOp(op);
return success();
}
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
private:
AxisInfoAnalysis &AxisAnalysisPass;
};
struct BroadcastOpConversion
@@ -909,8 +900,10 @@ struct ViewOpConversion
struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
benefit) {}
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
@@ -918,13 +911,12 @@ struct MakeRangeOpConversion
Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape();
auto blocked_layout =
rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs =
emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape);
auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
unsigned elems = idxs.size();
SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) {
@@ -941,92 +933,116 @@ struct MakeRangeOpConversion
};
struct LoadOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp> {
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value ptr = adaptor.ptr();
Value mask = adaptor.mask();
Value other = adaptor.other();
auto resultTy = op.result().getType().cast<RankedTensorType>();
auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto shape = resultTy.getShape();
// TODO: Handle AxisInfo
// vecWidth = std::min(nts, aln)
// TODO: special processing for mma_first_row in legacy codes
assert(blockedLayout && "LoadOp only accepts blocked_layout");
unsigned vecWidth =
blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]];
Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
auto elemTy = resultTy.getElementType();
unsigned numElems = getElemsPerThread(blockedLayout, shape);
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
SmallVector<Value> maskVals;
if (mask) {
maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
auto [layout, numElems] = getLayout(ptr);
auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
assert(ptrElems.size() == numElems);
SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
assert(ptrElems.size() == maskElems.size());
}
SmallVector<Value> otherVals;
if (other) {
otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
}
unsigned nbits = elemTy.isa<FloatType>()
? elemTy.cast<FloatType>().getWidth()
: elemTy.cast<IntegerType>().getWidth();
// unsigned dtsize = nbits / 8;
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits * vecWidth;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// TODO: currently disable until supported in `store`
bool has_l2_evict_policy = false;
// Determine the vectorization size
size_t vec = getVectorizeSize(ptr, layout);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
// TODO(goostavz): handle the case where other is constant but not splat,
// which should rarely occur
bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (elemTy.isa<IntegerType>() &&
if (valueElemTy.isa<IntegerType>() &&
matchPattern(op.other(), m_Constant(&constAttr)) &&
constAttr.isSplat()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems =
getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
SmallVector<Value> loadedVals;
for (size_t i = 0; i < numElems; i += vecWidth) {
Value ptr = ptrVals[i];
// TODO: Handle the optimization if ptr is from GEP and the idx is
// constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// TODO(Superjomn): should the masks be checked before vectorizing, since all
// the values share one predicate? Here we assume all the mask values are
// the same.
Value pred =
mask ? maskVals[i]
mask ? maskElems[vecStart]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1);
// ---
// create inline asm string
// ---
const std::string readConstrait =
const std::string readConstraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
const std::string writeConstrait =
const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
PTXBuilder ptxBuilder;
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (int i = 0; i < n_words; i++) {
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr);
}
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode
ld.predicate(pred, "b")
@@ -1037,11 +1053,12 @@ struct LoadOpConversion
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", has_l2_evict_policy)
.v(n_words)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);
PTXBuilder::Operand *evictOpr{};
// There is no mlir::Value to bind to this operation yet, so it is disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
@@ -1053,16 +1070,16 @@ struct LoadOpConversion
SmallVector<Value> others;
if (other) {
for (size_t ii = 0; ii < n_words; ii++) {
for (size_t ii = 0; ii < nWords; ii++) {
PTXInstr &mov = *ptxBuilder.create<>("mov");
mov.predicateNot(pred, "b").o("u", width);
size_t size = width / nbits;
size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t s = 0; s < size; s++) {
Value falseVal = otherVals[i + ii * size + s];
Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
@@ -1075,7 +1092,7 @@ struct LoadOpConversion
if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal);
} else {
opr = ptxBuilder.newOperand(v, readConstrait);
opr = ptxBuilder.newOperand(v, readConstraint);
others.push_back(v);
}
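The predicated mov sequence above fills lanes whose mask predicate is false with the `other` (fallback) value, and a splat integer `other` is emitted as an immediate instead of a register operand. At the Triton language level this corresponds to a masked load with a fallback; an illustrative snippet, assuming `ptrs` and `mask` are defined inside a @triton.jit kernel:

x = tl.load(ptrs, mask=mask, other=0.0)  # masked-off lanes receive 0.0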
@@ -1086,7 +1103,7 @@ struct LoadOpConversion
// ---
// create inline ASM signature
// ---
SmallVector<Type> retTys(n_words, IntegerType::get(getContext(), width));
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0];
@@ -1097,7 +1114,8 @@ struct LoadOpConversion
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
/*asm_string=*/ptxBuilder.dump(),
/*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true,
/*constraints=*/ptxBuilder.getConstrains(),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
Value ret = inlineAsmOp.getResult(0);
@@ -1106,8 +1124,8 @@ struct LoadOpConversion
// extract and store return values
// ---
SmallVector<Value> rets;
for (unsigned int ii = 0; ii < n_words; ii++) {
Value curr = nullptr;
for (unsigned int ii = 0; ii < nWords; ii++) {
Value curr;
if (retTy.isa<LLVM::LLVMStructType>()) {
curr = rewriter.create<LLVM::ExtractValueOp>(
loc, IntegerType::get(getContext(), width), ret,
@@ -1116,19 +1134,21 @@ struct LoadOpConversion
curr = ret;
}
curr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr);
loc, LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
rets.push_back(curr);
}
int tmp = (width / nbits);
for (size_t ii = 0; ii < vecWidth; ii++) {
int tmp = (width / valueElemNbits);
for (size_t ii = 0; ii < vec; ii++) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Value loaded = rewriter.create<LLVM::ExtractElementOp>(
loc, elemTy, rets[ii / tmp], vecIdx);
loc, valueElemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded);
}
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
@@ -1272,11 +1292,16 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulIOp, LLVM::MulOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<GEPOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);

View File

@@ -10,7 +10,6 @@ using namespace mlir::triton;
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>();

View File

@@ -1649,6 +1649,10 @@ void init_triton_ir(py::module &&m) {
.def(
"add_sccp_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
.def("add_coalesce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
.def("add_symbol_dce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createSymbolDCEPass());

View File

@@ -29,7 +29,6 @@ def test_empty_kernel_cubin_compile():
def test_empty_kernel_launch():
device = torch.cuda.current_device()
binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
num_warps=4,
num_stages=3)
@@ -38,11 +37,9 @@ def test_empty_kernel_launch():
)
A = torch.zeros([1024], device="cuda")
runtime.launch_kernel(fn=empty_kernel,
binary=binary,
runtime.launch_kernel(kernel=binary,
grid=grid,
num_warps=4,
num_stages=3,
device=device,
X=A,
stride_xm=256,
BLOCK=tl.constexpr(256))

View File

@@ -5,17 +5,12 @@ import triton
import triton.language as tl
import triton.runtime as runtime
NUM_WARPS = 4
BLOCK_SIZE = 256
# triton kernel
def test_vecadd_no_scf():
def vecadd_no_scf_tester(num_warps, block_size):
@triton.jit
def kernel(x_ptr, stride_xn,
y_ptr, stride_yn,
z_ptr, stride_zn,
def kernel(x_ptr,
y_ptr,
z_ptr,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -27,37 +22,35 @@ def test_vecadd_no_scf():
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
# TODO: add this to CI to make sure the compilation flow is at least OK
# before we have GPU machines for CI.
# ptx, shem_size, kernel_name = triton.compile(kernel,
# "*fp32,i32,*fp32,i32,*fp32,i32",
# constants={"BLOCK_SIZE_N": 256},
# num_warps=NUM_WARPS,
# device=0, output="ptx")
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
device=device,
constants={"BLOCK_SIZE_N": BLOCK_SIZE},
num_warps=NUM_WARPS,
binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
constants={"BLOCK_SIZE_N": block_size},
num_warps=num_warps,
num_stages=3)
grid = lambda META: (1, )
x = torch.randn((256,), device='cuda', dtype=torch.float32)
y = torch.randn((256,), device='cuda', dtype=torch.float32)
z = torch.empty((256,), device=x.device, dtype=x.dtype)
runtime.launch_kernel(fn=kernel,
binary=binary,
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
assert x.shape.numel() % block_size == 0, "Only test load without mask here"
grid = lambda EA: (x.shape.numel() // block_size,)
runtime.launch_kernel(kernel=binary,
grid=grid,
num_warps=NUM_WARPS,
num_stages=3,
device=device,
x_ptr=x,
stride_xn=x.stride(0),
y_ptr=y,
stride_yn=y.stride(0),
z_ptr=z,
stride_zn=z.stride(0),
BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
BLOCK_SIZE_N=tl.constexpr(block_size))
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
def test_vecadd_no_scf():
vecadd_no_scf_tester(num_warps=2, block_size=256)
vecadd_no_scf_tester(num_warps=1, block_size=256)
if __name__ == '__main__':
test_vecadd_no_scf()

View File

@@ -798,7 +798,8 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_cse_pass()
# pm.add_triton_gpu_combine_pass()
pm.add_coalesce_pass()
pm.add_triton_gpu_combine_pass()
pm.add_triton_gpu_verifier_pass()
pm.run(mod)
return mod

View File

@@ -8,7 +8,7 @@ import os
import subprocess
import tempfile
import textwrap
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import torch
@@ -256,41 +256,126 @@ class JITFunction:
return f"JITFunction({self.module}:{self.fn.__name__})"
def pow2_divisor(N):
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
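pow2_divisor(N) returns the largest power of two (capped at 16) that divides N; launch_kernel below combines it with the pointer's range size to derive a per-argument divisibility hint measured in elements. A worked example with illustrative values for a contiguous fp32 tensor:

# illustrative inputs; the real values come from arg.data_ptr() and
# _triton.runtime.get_pointer_range_size(addr), with pow2_divisor as above
addr, range_size, elem_bytes = 0x7f4a2c000000, 1024, 4
divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // elem_bytes
print(divisibility)  # 4: 16-byte divisibility -> 4 fp32 elements, consistent
                     # with the tt.divisibility = 4 attribute in the new lit
                     # test later in this diff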
class _KernelCache:
def __init__(self,
fn: JITFunction,
fn_type: str,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3):
# hold the arguments for building a kernel
self.fn = fn
self.fn_type = fn_type
self.constants = constants
self.num_warps = num_warps
self.num_stages = num_stages
# kernel compilation cache
self._binary_cache: Optional[LoadedBinary] = None
@property
def binary_cache(self):
return self._binary_cache
def set_binary_cache(self, binary: LoadedBinary):
assert binary
assert not self._binary_cache, "the binary cache can only be set once"
self._binary_cache = binary
def build_kernel(fn: JITFunction,
fn_type: str,
device: int,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3,
) -> LoadedBinary:
cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin")
assert cubin
assert kernel_name
backend = _triton.runtime.backend.CUDA
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
asm = dict(cubin=cubin)
binary = Binary(backend, kernel_name, asm, shem_size, num_warps)
loaded_binary = LoadedBinary(device, binary)
return loaded_binary
) -> _KernelCache:
return _KernelCache(fn, fn_type, constants, num_warps, num_stages)
def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
torch_dtype_to_bytes = {
torch.int8: 1,
torch.uint8: 1,
torch.int16: 2,
torch.short: 2,
torch.int: 4,
torch.int32: 4,
torch.long: 8,
torch.int64: 8,
torch.float32: 4,
torch.float: 4,
torch.float16: 2,
torch.half: 2,
torch.bfloat16: 2,
# free to extend
}
def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
def is_tensor(arg):
return hasattr(arg, 'data_ptr') # a torch.tensor
# prepare function args for compile
kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
assert len(wargs) == len(fn.arg_names), "Function argument list does not match: need %d args but got %d" % (len(fn.arg_names), len(wargs))
assert len(wargs) == len(kernel.fn.arg_names), "Function argument list does not match: need %d args but got %d" % (len(kernel.fn.arg_names), len(wargs))
if not kernel.binary_cache:
# build the kernel cache
backend = _triton.runtime.backend.CUDA
attributes = dict()
for i, arg in enumerate(wargs):
if i in kernel.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = pow2_divisor(arg)
elif is_tensor(arg):
assert arg.dtype in torch_dtype_to_bytes
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
attributes[i] = divisibility
attributes_ = dict()
for i, value in attributes.items():
attributes_[kernel.fn.arg_names[i]] = value
cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
assert cubin
assert kernel_name
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
asm = dict(cubin=cubin)
binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
loaded_binary = LoadedBinary(device, binary)
kernel.set_binary_cache(loaded_binary)
device = torch.cuda.current_device()
torch.cuda.set_device(device)
stream = get_cuda_stream(device)
_triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names,
stream, num_warps, num_stages, grid)
_triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
stream, kernel.num_warps, kernel.num_stages, grid)
# -----------------------------------------------------------------------------
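Taken together, build_kernel now only records the build parameters in a _KernelCache; the first launch_kernel call compiles with the argument-derived attributes and caches the resulting LoadedBinary. A usage sketch mirroring the updated vecadd test in this diff (the @triton.jit vector-add kernel is assumed to be defined as in that test):

import torch
import triton.language as tl
import triton.runtime as runtime

binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
                              constants={"BLOCK_SIZE_N": 256},
                              num_warps=4,
                              num_stages=3)
x = torch.randn((256,), device='cuda', dtype=torch.float32)
y = torch.randn((256,), device='cuda', dtype=torch.float32)
z = torch.empty((256,), device=x.device, dtype=x.dtype)
runtime.launch_kernel(kernel=binary,            # compiles and caches on first call
                      grid=lambda meta: (x.shape.numel() // 256,),
                      device=torch.cuda.current_device(),
                      x_ptr=x,
                      y_ptr=y,
                      z_ptr=z,
                      BLOCK_SIZE_N=tl.constexpr(256))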

View File

@@ -28,14 +28,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load
func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v4.b32
// CHECK-SAME: ld.global.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v4.b32
// CHECK-SAME: ld.global.b32
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
@@ -43,14 +43,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: vectorized_load_f16
func @vectorized_load_f16(%a_ptr_init : tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v2.b32
// CHECK-SAME: ld.global.b16
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.v2.b32
// CHECK-SAME: ld.global.b16
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
return
}
@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
// TODO: pending support for isSplat constants
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
@@ -69,6 +69,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256
func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK: ld.global.v4.b32
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// CHECK: ld.global.v4.b32
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global
// CHECK: st.global.b32.v4
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
// is from a GEP with const idx
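The end-to-end testcase above is roughly the TTGIR produced for a simple vector-add kernel. A reconstruction of the corresponding Triton source, inferred from the IR rather than copied from this diff's tests:

import triton
import triton.language as tl

@triton.jit
def kernel(x_ptr, y_ptr, z_ptr, BLOCK_SIZE_N: tl.constexpr):
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    x = tl.load(x_ptr + offset)      # checked above: ld.global.v4.b32
    y = tl.load(y_ptr + offset)      # checked above: ld.global.v4.b32
    tl.store(z_ptr + offset, x + y)  # checked above: vectorized st.global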
@@ -99,7 +133,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
func @basic_make_range() {