[Backend] Vectorize Load/Store Ops (#86)

This PR does the following:

- Refactors the Load and Store op codegen: both are now lowered with the same logic and share most of their code.
- Adds support for vectorized load/store.
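Note on the vectorization rule the new Load/Store path shares: the vector width of an access is the pointer's alignment (taken from AxisInfoAnalysis as the smaller of its divisibility and contiguity along the fastest-varying dimension), capped by the number of elements each thread holds. A rough Python model of that arithmetic (the helper name and concrete numbers are illustrative, not part of the codebase):

    def vector_width(divisibility, contiguity, elems_per_thread):
        # getAlignment: consecutive elements that are both contiguous and
        # start at a suitably divisible address
        alignment = min(divisibility, contiguity)
        # getVectorizeSize: cap by what a single thread actually owns
        return min(alignment, elems_per_thread)

    # 256 fp32 elements over 32 threads/warp x 2 warps -> 4 elements per thread;
    # with divisibility 4 this yields v4.b32 load/store accesses
    print(vector_width(4, 256, 256 // (32 * 2)))  # -> 4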
Author: Yan Chunwei
Date: 2022-09-07 03:28:09 +08:00
Committed by: GitHub
Parent: 35e346bcff
Commit: a9464f4993
10 changed files with 433 additions and 295 deletions


@@ -27,10 +27,10 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (isSharedEncoding(result)) { if (isSharedEncoding(result)) {
aliasInfo.insert(result); aliasInfo.insert(result);
pessimistic = false; pessimistic = false;
}
} else { } else {
llvm::errs() << "op: " << op->getName() << "\n"; llvm::errs() << "op: " << op->getName() << "\n";
} }
}
// XXX(Keren): triton ops don't support aliasing yet. // XXX(Keren): triton ops don't support aliasing yet.
// else if (auto viewOp = dyn_cast<triton::ViewOp>(op) || // else if (auto viewOp = dyn_cast<triton::ViewOp>(op) ||
// dyn_cast<triton::ExpandDimsOp>(op)) { // dyn_cast<triton::ExpandDimsOp>(op)) {


@@ -1,5 +1,6 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -307,16 +308,7 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
return linear_index; return linear_index;
} }
template <typename SourceOp> struct ConvertTritonGPUOpToLLVMPatternBase {
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
SmallVector<Value> SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems, getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
@@ -329,18 +321,18 @@ public:
} }
return results; return results;
} }
};
Value getStructFromElements(Location loc, ValueRange resultVals, template <typename SourceOp>
ConversionPatternRewriter &rewriter, class ConvertTritonGPUOpToLLVMPattern
Type structType) const { : public ConvertOpToLLVMPattern<SourceOp>,
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType); public ConvertTritonGPUOpToLLVMPatternBase {
for (auto v : llvm::enumerate(resultVals)) { public:
llvmStruct = rewriter.create<LLVM::InsertValueOp>( using OpAdaptor = typename SourceOp::Adaptor;
loc, structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index())); explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
} PatternBenefit benefit = 1)
return llvmStruct; : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter, SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear, Location loc, Value linear,
@@ -432,24 +424,18 @@ public:
for (unsigned blockOffset = 0; for (unsigned blockOffset = 0;
blockOffset < blockOffset <
shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]); shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]);
++blockOffset) { ++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
++warpOffset) {
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k]; for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
++threadOffset) { ++threadOffset)
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k]; for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
++elemOffset) { ++elemOffset)
offset[k].push_back(blockOffset * sizePerThread[k] * offset[k].push_back(blockOffset * sizePerThread[k] *
threadsPerWarp[k] * warpsPerCTA[k] + threadsPerWarp[k] * warpsPerCTA[k] +
warpOffset * sizePerThread[k] * warpOffset * sizePerThread[k] *
threadsPerWarp[k] + threadsPerWarp[k] +
threadOffset * sizePerThread[k] + elemOffset); threadOffset * sizePerThread[k] + elemOffset);
} }
}
}
}
}
// step 3, add offset to base, and reorder the sequence of indices, // step 3, add offset to base, and reorder the sequence of indices,
// to guarantee that elems in a same sizePerThread are adjacent in // to guarantee that elems in a same sizePerThread are adjacent in
// order // order
@@ -535,9 +521,9 @@ struct SplatOpConversion
} }
}; };
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), the // This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
// logic is the same as triton::SplatOp, so the underlying implementation is // the logic is the same as triton::SplatOp, so the underlying implementation
// reused. // is reused.
struct ArithConstantSplatOpConversion struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> { : public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern< using ConvertTritonGPUOpToLLVMPattern<
@@ -576,20 +562,104 @@ struct ArithConstantSplatOpConversion
} }
}; };
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: AxisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value.
SmallVector<Value> getLLVMElems(Value value, Value llValue,
const BlockedEncodingAttr &layout,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) const {
if (!value)
return {};
auto ty = value.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = getElemsPerThread(layout, shape);
auto llvmElemTy = typeConverter->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
}
// Get the blocked layout.
std::tuple<BlockedEncodingAttr, unsigned> getLayout(Value val) const {
auto ty = val.getType().cast<RankedTensorType>();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto shape = ty.getShape();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
}
unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(val);
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
}
unsigned getVectorizeSize(Value ptr,
const BlockedEncodingAttr &layout) const {
auto axisInfo = getAxisInfo(ptr);
auto contig = axisInfo->getContiguity();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = layout.getOrder();
unsigned align = getAlignment(ptr, layout);
auto getTensorShape = [](Value val) -> ArrayRef<int64_t> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
return shape;
};
// unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr));
unsigned vec = std::min(align, contigPerThread);
return vec;
}
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
protected:
AxisInfoAnalysis &AxisAnalysisPass;
};
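One behavioral detail of getVectorizeSize above: the per-thread cap is now getElemsPerThread(layout, shape) rather than layout.getSizePerThread()[order[0]] (the old expression is kept in the commented-out line). For the 1-D layout used in the new lit test the two differ; a quick numeric comparison (values lifted from that test, purely illustrative):

    # 256 elements, sizePerThread = 1, threadsPerWarp = 32, warpsPerCTA = 2
    alignment = 4                            # min(divisibility, contiguity)
    size_per_thread = 1                      # old cap
    elems_per_thread = 256 // (32 * 2)       # new cap -> 4
    print(min(alignment, size_per_thread))   # old: 1  (scalar b32 access)
    print(min(alignment, elems_per_thread))  # new: 4  (v4.b32 access)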
struct StoreOpConversion struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> { : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern< using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter, StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit), : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
AxisAnalysisPass(axisAnalysisPass) {} LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr(); Value ptr = op.ptr();
Value mask = op.mask(); Value mask = op.mask();
Value value = op.value(); Value value = op.value();
@@ -598,125 +668,52 @@ struct StoreOpConversion
Value llMask = adaptor.mask(); Value llMask = adaptor.mask();
Value llValue = adaptor.value(); Value llValue = adaptor.value();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType().dyn_cast<RankedTensorType>(); auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) if (!valueTy)
return failure(); return failure();
Type valueElemTy = Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType()); getTypeConverter()->convertType(valueTy.getElementType());
MLIRContext *ctx = rewriter.getContext(); auto [layout, numElems] = getLayout(ptr);
auto loc = op->getLoc();
auto getLLVMElems = auto ptrElems =
[&](Value value, Value llValue, getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
const BlockedEncodingAttr &layout) -> SmallVector<Value> { auto valueElems =
auto ty = value.getType().cast<RankedTensorType>(); getLLVMElems(value, llValue, layout, getTypeConverter(), rewriter, loc);
auto shape = ty.getShape(); assert(ptrElems.size() == valueElems.size());
// Here, we assume that all inputs should have a blockedLayout
unsigned valueElems = getElemsPerThread(layout, shape);
auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
auto llvmElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));
auto valueVals =
getElementsFromStruct(loc, llValue, valueElems, rewriter);
return valueVals;
};
auto getLayout =
[&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
auto ty = val.getType().cast<RankedTensorType>();
auto shape = ty.getShape();
// Here, we assume that all inputs should have a blockedLayout
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
unsigned valueElems = getElemsPerThread(layout, shape);
return std::make_tuple(layout, valueElems);
};
auto [ptrLayout, ptrNumElems] = getLayout(ptr);
auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(ptr, llPtr, ptrLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
SmallVector<Value> maskElems; SmallVector<Value> maskElems;
if (llMask) { if (llMask) {
auto [maskLayout, maskNumElems] = getLayout(mask); maskElems =
maskElems = getLLVMElems(mask, llMask, maskLayout); getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
assert(valueElems.size() == maskElems.size()); assert(valueElems.size() == maskElems.size());
} }
auto getAlign = [this](Value val,
const BlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
assert(axisInfo.hasValue());
auto order = layout.getOrder();
unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
unsigned maxContig = axisInfo->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
return alignment;
};
// get align
auto getVec = [this,
&getAlign](Value val,
const BlockedEncodingAttr &layout) -> unsigned {
auto axisInfo = getAxisInfo(val);
auto contig = axisInfo->getContiguity();
// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = layout.getOrder();
unsigned align = getAlign(val, layout);
assert(!order.empty());
// Is this right?
unsigned contigPerThread = layout.getSizePerThread()[order[0]];
unsigned vec = std::min(align, contigPerThread);
// TODO(Superjomn) Consider the is_mma_first_row in the legacy code
bool isMMAFirstRow = false;
if (isMMAFirstRow)
vec = std::min<size_t>(2, align);
return vec;
};
// Determine the vectorization size // Determine the vectorization size
size_t vec = getVec(ptr, ptrLayout); size_t vec = getVectorizeSize(ptr, layout);
const size_t dtsize = value.getType() const size_t dtsize =
.cast<RankedTensorType>() std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
.getElementType()
.getIntOrFloatBitWidth() /
8;
const size_t valueElemNbits = dtsize * 8; const size_t valueElemNbits = dtsize * 8;
const int numVecs = ptrNumElems / vec; const int numVecs = numElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) { for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset // TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0; size_t in_off = 0;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
const int maxWordWidth = std::max<int>(32, valueElemNbits); const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec; const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth); const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width); const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits; const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits; const int vecNElems = totalWidth / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
assert(wordNElems * nWords * numVecs == valueElems.size()); // TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
// TODO(Superjomn) Add cache policy to store.
// TODO(Superjomn) deal with cache policy.
const bool hasL2EvictPolicy = false; const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder; PTXBuilder ptxBuilder;
@@ -733,8 +730,9 @@ struct StoreOpConversion
Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy); Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
// Insert each value element to the composition // Insert each value element to the composition
for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) { for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
Value elem = const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx]; assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1)) if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem); elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem); elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);
@@ -751,13 +749,17 @@ struct StoreOpConversion
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
} }
// TODO(Superjomn) Need to check masks before vectorize the load for all
// the values share one predicate? Here assume all the mask values are
// the same.
Value maskVal = Value maskVal =
llMask ? maskElems[vecIdx] llMask ? maskElems[vecStart]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1); rewriter.getIntegerType(1), 1);
ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords); ptxStoreInstr.predicate(maskVal, "b").global().b(width).v(nWords);
auto *asmAddr = ptxBuilder.newAddrOperand(ptrElems[vecIdx], "l", in_off); auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
ptxStoreInstr(asmAddr, asmArgList); ptxStoreInstr(asmAddr, asmArgList);
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -782,17 +784,6 @@ struct StoreOpConversion
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
return it->getValue();
}
return llvm::Optional<AxisInfo>{};
}
private:
AxisInfoAnalysis &AxisAnalysisPass;
}; };
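Both conversions pack sub-word elements into at-most-32-bit words before emitting the PTX instruction; the width/word bookkeeping in the store loop above reduces to the following (a condensed model of the arithmetic, not the codegen itself):

    def pack_words(elem_bits, vec):
        max_word_width = max(32, elem_bits)       # widest PTX register operand used
        total_width = elem_bits * vec             # bits moved by one vectorized access
        width = min(total_width, max_word_width)  # width of each .b<width> operand
        n_words = max(1, total_width // width)    # operands per ld/st (the .v<n> suffix)
        word_n_elems = width // elem_bits         # elements packed into one word
        return width, n_words, word_n_elems

    print(pack_words(32, 4))  # fp32, vec=4 -> (32, 4, 1): a v4.b32 store/load
    print(pack_words(16, 1))  # fp16, vec=1 -> (16, 1, 1): the scalar b16 load in the tests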
struct BroadcastOpConversion struct BroadcastOpConversion
@@ -909,8 +900,10 @@ struct ViewOpConversion
struct MakeRangeOpConversion struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> { : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern; MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
benefit) {}
LogicalResult LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
@@ -918,13 +911,12 @@ struct MakeRangeOpConversion
Location loc = op->getLoc(); Location loc = op->getLoc();
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>(); auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
auto shape = rankedTy.getShape(); auto shape = rankedTy.getShape();
auto blocked_layout = auto layout = rankedTy.getEncoding().cast<BlockedEncodingAttr>();
rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto elemTy = rankedTy.getElementType(); auto elemTy = rankedTy.getElementType();
assert(elemTy.isInteger(32)); assert(elemTy.isInteger(32));
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start()); Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
auto idxs = auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape);
emitIndicesForBlockedLayout(loc, rewriter, blocked_layout, shape);
unsigned elems = idxs.size(); unsigned elems = idxs.size();
SmallVector<Value> retVals(elems); SmallVector<Value> retVals(elems);
for (auto multiDim : llvm::enumerate(idxs)) { for (auto multiDim : llvm::enumerate(idxs)) {
@@ -941,92 +933,116 @@ struct MakeRangeOpConversion
}; };
struct LoadOpConversion struct LoadOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp> { : public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern< using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value ptr = adaptor.ptr();
Value mask = adaptor.mask();
Value other = adaptor.other();
auto resultTy = op.result().getType().cast<RankedTensorType>();
auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto shape = resultTy.getShape();
// TODO: Handle AxisInfo Value ptr = op.ptr();
// vecWidth = std::min(nts, aln) Value mask = op.mask();
// TODO: special processing for mma_first_row in legacy codes Value other = op.other();
assert(blockedLayout && "LoadOp only accepts blocked_layout");
unsigned vecWidth =
blockedLayout.getSizePerThread()[blockedLayout.getOrder()[0]];
auto elemTy = resultTy.getElementType(); Value llPtr = adaptor.ptr();
unsigned numElems = getElemsPerThread(blockedLayout, shape); Value llMask = adaptor.mask();
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter); Value llOther = adaptor.other();
SmallVector<Value> maskVals;
if (mask) { auto loc = op->getLoc();
maskVals = getElementsFromStruct(loc, mask, numElems, rewriter); MLIRContext *ctx = rewriter.getContext();
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
auto [layout, numElems] = getLayout(ptr);
auto ptrElems =
getLLVMElems(ptr, llPtr, layout, getTypeConverter(), rewriter, loc);
assert(ptrElems.size() == numElems);
SmallVector<Value> maskElems;
if (llMask) {
maskElems =
getLLVMElems(mask, llMask, layout, getTypeConverter(), rewriter, loc);
assert(ptrElems.size() == maskElems.size());
} }
SmallVector<Value> otherVals;
if (other) { // Determine the vectorization size
otherVals = getElementsFromStruct(loc, other, numElems, rewriter); size_t vec = getVectorizeSize(ptr, layout);
}
unsigned nbits = elemTy.isa<FloatType>() const size_t dtsize =
? elemTy.cast<FloatType>().getWidth() std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
: elemTy.cast<IntegerType>().getWidth(); const size_t valueElemNbits = dtsize * 8;
// unsigned dtsize = nbits / 8;
int max_word_width = std::max<int>(32, nbits); const int numVecs = numElems / vec;
int tot_width = nbits * vecWidth;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// TODO: currently disable until supported in `store`
bool has_l2_evict_policy = false;
// TODO: (goostavz) handle when other is const but not splat, which // TODO: (goostavz) handle when other is const but not splat, which
// should be rarely seen // should be rarely seen
bool otherIsSplatConstInt = false; bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr; DenseElementsAttr constAttr;
int64_t splatVal = 0; int64_t splatVal = 0;
if (elemTy.isa<IntegerType>() && if (valueElemTy.isa<IntegerType>() &&
matchPattern(op.other(), m_Constant(&constAttr)) && matchPattern(op.other(), m_Constant(&constAttr)) &&
constAttr.isSplat()) { constAttr.isSplat()) {
otherIsSplatConstInt = true; otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue(); splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
} }
auto otherElems =
getLLVMElems(other, llOther, layout, getTypeConverter(), rewriter, loc);
SmallVector<Value> loadedVals; SmallVector<Value> loadedVals;
for (size_t i = 0; i < numElems; i += vecWidth) { for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
Value ptr = ptrVals[i]; // TODO: optimization when ptr is GEP with constant offset
// TODO: Handle the optimization if ptr is from GEP and the idx is size_t in_off = 0;
// constant. This should be a canonicalization pattern in LLVM Dialect
unsigned in_off = 0; const int maxWordWidth = std::max<int>(32, valueElemNbits);
const int totalWidth = valueElemNbits * vec;
const int width = std::min(totalWidth, maxWordWidth);
const int nWords = std::max(1, totalWidth / width);
const int wordNElems = width / valueElemNbits;
const int vecNElems = totalWidth / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
auto &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// TODO(Superjomn) Need to check masks before vectorize the load for all
// the values share one predicate? Here assume all the mask values are
// the same.
Value pred = Value pred =
mask ? maskVals[i] mask ? maskElems[vecStart]
: createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(),
rewriter.getIntegerType(1), 1); rewriter.getIntegerType(1), 1);
// --- const std::string readConstraint =
// create inline asm string
// ---
const std::string readConstrait =
(width == 64) ? "l" : ((width == 32) ? "r" : "c"); (width == 64) ? "l" : ((width == 32) ? "r" : "c");
const std::string writeConstrait = const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
PTXBuilder ptxBuilder;
PtxIOInstr &ld = *ptxBuilder.create<PtxIOInstr>("ld");
// prepare asm operands // prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand(); auto *dstsOpr = ptxBuilder.newListOperand();
for (int i = 0; i < n_words; i++) { for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
auto *opr = ptxBuilder.newOperand(writeConstrait); // =r operations auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr); dstsOpr->listAppend(opr);
} }
auto *addrOpr = ptxBuilder.newAddrOperand(ptr, "l", in_off);
auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode // Define the instruction opcode
ld.predicate(pred, "b") ld.predicate(pred, "b")
@@ -1037,11 +1053,12 @@ struct LoadOpConversion
.o("L1::evict_first", .o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST) op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST) .o("L1::evict_last", op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", has_l2_evict_policy) .o("L1::cache_hint", hasL2EvictPolicy)
.v(n_words) .v(nWords)
.b(width); .b(width);
PTXBuilder::Operand *evictOpr{}; PTXBuilder::Operand *evictOpr{};
// Here lack a mlir::Value to bind to this operation, so disabled. // Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy) // if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l"); // evictOpr = ptxBuilder.newOperand(l2Evict, "l");
@@ -1053,16 +1070,16 @@ struct LoadOpConversion
SmallVector<Value> others; SmallVector<Value> others;
if (other) { if (other) {
for (size_t ii = 0; ii < n_words; ii++) { for (size_t ii = 0; ii < nWords; ii++) {
PTXInstr &mov = *ptxBuilder.create<>("mov"); PTXInstr &mov = *ptxBuilder.create<>("mov");
mov.predicateNot(pred, "b").o("u", width); mov.predicateNot(pred, "b").o("u", width);
size_t size = width / nbits; size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(elemTy, size); auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy); Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t s = 0; s < size; s++) { for (size_t s = 0; s < size; s++) {
Value falseVal = otherVals[i + ii * size + s]; Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant( Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s); rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal, v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
@@ -1075,7 +1092,7 @@ struct LoadOpConversion
if (otherIsSplatConstInt) { if (otherIsSplatConstInt) {
opr = ptxBuilder.newConstantOperand(splatVal); opr = ptxBuilder.newConstantOperand(splatVal);
} else { } else {
opr = ptxBuilder.newOperand(v, readConstrait); opr = ptxBuilder.newOperand(v, readConstraint);
others.push_back(v); others.push_back(v);
} }
@@ -1086,7 +1103,7 @@ struct LoadOpConversion
// --- // ---
// create inline ASM signature // create inline ASM signature
// --- // ---
SmallVector<Type> retTys(n_words, IntegerType::get(getContext(), width)); SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1 Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys) ? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0]; : retTys[0];
@@ -1097,7 +1114,8 @@ struct LoadOpConversion
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>( auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(), loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(),
/*asm_string=*/ptxBuilder.dump(), /*asm_string=*/ptxBuilder.dump(),
/*constraints=*/ptxBuilder.getConstrains(), /*has_side_effects=*/true, /*constraints=*/ptxBuilder.getConstrains(),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr()); /*operand_attrs=*/ArrayAttr());
Value ret = inlineAsmOp.getResult(0); Value ret = inlineAsmOp.getResult(0);
@@ -1106,8 +1124,8 @@ struct LoadOpConversion
// extract and store return values // extract and store return values
// --- // ---
SmallVector<Value> rets; SmallVector<Value> rets;
for (unsigned int ii = 0; ii < n_words; ii++) { for (unsigned int ii = 0; ii < nWords; ii++) {
Value curr = nullptr; Value curr;
if (retTy.isa<LLVM::LLVMStructType>()) { if (retTy.isa<LLVM::LLVMStructType>()) {
curr = rewriter.create<LLVM::ExtractValueOp>( curr = rewriter.create<LLVM::ExtractValueOp>(
loc, IntegerType::get(getContext(), width), ret, loc, IntegerType::get(getContext(), width), ret,
@@ -1116,19 +1134,21 @@ struct LoadOpConversion
curr = ret; curr = ret;
} }
curr = rewriter.create<LLVM::BitcastOp>( curr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::getFixedVectorType(elemTy, width / nbits), curr); loc, LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits),
curr);
rets.push_back(curr); rets.push_back(curr);
} }
int tmp = (width / nbits); int tmp = (width / valueElemNbits);
for (size_t ii = 0; ii < vecWidth; ii++) { for (size_t ii = 0; ii < vec; ii++) {
Value vecIdx = createIndexAttrConstant( Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Value loaded = rewriter.create<LLVM::ExtractElementOp>( Value loaded = rewriter.create<LLVM::ExtractElementOp>(
loc, elemTy, rets[ii / tmp], vecIdx); loc, valueElemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded); loadedVals.push_back(loaded);
} }
} // end vec } // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(resultTy);
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct = Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy); getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct}); rewriter.replaceOp(op, {resultStruct});
@@ -1272,11 +1292,16 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit); benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter, patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
benefit); benefit);
patterns.add<BinaryOpConversion<arith::MulIOp, LLVM::MulOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit); patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit); patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<GEPOpConversion>(typeConverter, benefit); patterns.add<GEPOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit); patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit); patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit); patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit); patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit); patterns.add<SplatOpConversion>(typeConverter, benefit);
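For masked loads the generated PTX uses the predicate twice: @p guards the ld.global, and the @!p mov.u fallback writes the `other` value, so per element the result behaves like the select below (a plain-Python model of the semantics, not of the emitted assembly):

    def masked_load(loaded, mask, other):
        # element-wise: take the value from memory where the mask is set, `other` elsewhere
        return [m_val if m else o for m_val, m, o in zip(loaded, mask, other)]

    print(masked_load([1.0, 2.0, 3.0, 4.0], [True, False, True, False], [0.0] * 4))
    # -> [1.0, 0.0, 3.0, 0.0]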


@@ -10,7 +10,6 @@ using namespace mlir::triton;
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> { struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr, Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) { int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>(); auto origType = ptr.getType().cast<RankedTensorType>();


@@ -1649,6 +1649,10 @@ void init_triton_ir(py::module &&m) {
       .def(
           "add_sccp_pass",
           [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
+      .def("add_coalesce_pass",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::createTritonGPUCoalescePass());
+           })
       .def("add_symbol_dce_pass",
            [](mlir::PassManager &self) {
              self.addPass(mlir::createSymbolDCEPass());


@@ -29,7 +29,6 @@ def test_empty_kernel_cubin_compile():
 def test_empty_kernel_launch():
     device = torch.cuda.current_device()
     binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
-                                  device=device,
                                   constants={"BLOCK": 256},
                                   num_warps=4,
                                   num_stages=3)
@@ -38,11 +37,9 @@
     )
     A = torch.zeros([1024], device="cuda")
-    runtime.launch_kernel(fn=empty_kernel,
-                          binary=binary,
+    runtime.launch_kernel(kernel=binary,
                           grid=grid,
-                          num_warps=4,
-                          num_stages=3,
+                          device=device,
                           X=A,
                           stride_xm=256,
                           BLOCK=tl.constexpr(256))
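These test edits follow the reworked runtime API: build_kernel no longer takes device= and no longer compiles eagerly (it returns a kernel-cache object), and launch_kernel(kernel=..., grid=..., device=..., **args) compiles on the first launch using the observed argument specialization. A minimal usage sketch (copy_kernel and its signature string are made up for illustration):

    import torch
    import triton
    import triton.language as tl
    import triton.runtime as runtime

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs))

    device = torch.cuda.current_device()
    binary = runtime.build_kernel(copy_kernel, "*fp32,*fp32,i32",
                                  constants={"BLOCK": 256},
                                  num_warps=4, num_stages=3)

    src = torch.randn(256, device="cuda")
    dst = torch.empty_like(src)
    runtime.launch_kernel(kernel=binary, grid=lambda meta: (1,), device=device,
                          src_ptr=src, dst_ptr=dst, BLOCK=tl.constexpr(256))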


@@ -5,17 +5,12 @@ import triton
 import triton.language as tl
 import triton.runtime as runtime
 
-NUM_WARPS = 4
-BLOCK_SIZE = 256
 
-# triton kernel
-def test_vecadd_no_scf():
+def vecadd_no_scf_tester(num_warps, block_size):
     @triton.jit
-    def kernel(x_ptr, stride_xn,
-               y_ptr, stride_yn,
-               z_ptr, stride_zn,
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
                BLOCK_SIZE_N: tl.constexpr):
         pid = tl.program_id(axis=0)
         offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -27,37 +22,35 @@ def test_vecadd_no_scf():
         z_ptrs = z_ptr + offset
         tl.store(z_ptrs, z)
 
-    # TODO: add this to CI, to make sure the the compilation flow is at lease OK
-    # before we have GPU machines for CI.
-    # ptx, shem_size, kernel_name = triton.compile(kernel,
-    #                                              "*fp32,i32,*fp32,i32,*fp32,i32",
-    #                                              constants={"BLOCK_SIZE_N": 256},
-    #                                              num_warps=NUM_WARPS,
-    #                                              device=0, output="ptx")
     torch.zeros([10], device=torch.device('cuda'))
     device = torch.cuda.current_device()
-    binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
-                                  device=device,
-                                  constants={"BLOCK_SIZE_N": BLOCK_SIZE},
-                                  num_warps=NUM_WARPS,
+    binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
+                                  constants={"BLOCK_SIZE_N": block_size},
+                                  num_warps=num_warps,
                                   num_stages=3)
-    grid = lambda META: (1, )
-    x = torch.randn((256,), device='cuda', dtype=torch.float32)
-    y = torch.randn((256,), device='cuda', dtype=torch.float32)
-    z = torch.empty((256,), device=x.device, dtype=x.dtype)
-    runtime.launch_kernel(fn=kernel,
-                          binary=binary,
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+    assert x.shape.numel() % block_size == 0, "Only test load without mask here"
+    grid = lambda EA: (x.shape.numel() // block_size,)
+    runtime.launch_kernel(kernel=binary,
                           grid=grid,
-                          num_warps=NUM_WARPS,
-                          num_stages=3,
+                          device=device,
                           x_ptr=x,
-                          stride_xn=x.stride(0),
                           y_ptr=y,
-                          stride_yn=y.stride(0),
                           z_ptr=z,
-                          stride_zn=z.stride(0),
-                          BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
+                          BLOCK_SIZE_N=tl.constexpr(block_size))
 
     golden_z = x + y
     assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def test_vecadd_no_scf():
+    vecadd_no_scf_tester(num_warps=2, block_size=256)
+    vecadd_no_scf_tester(num_warps=1, block_size=256)
+
+
+if __name__ == '__main__':
+    test_vecadd_no_scf()
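If more (num_warps, block_size) combinations become interesting later, the tester also slots directly into pytest parametrization (a possible follow-up sketch, not part of this PR):

    import pytest

    @pytest.mark.parametrize("num_warps, block_size", [(1, 256), (2, 256), (4, 256)])
    def test_vecadd_no_scf_param(num_warps, block_size):
        vecadd_no_scf_tester(num_warps, block_size)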


@@ -798,7 +798,8 @@ def optimize_tritongpu_ir(mod, num_stages):
     pm.add_tritongpu_pipeline_pass(num_stages)
     pm.add_canonicalizer_pass()
     pm.add_cse_pass()
-    # pm.add_triton_gpu_combine_pass()
+    pm.add_coalesce_pass()
+    pm.add_triton_gpu_combine_pass()
     pm.add_triton_gpu_verifier_pass()
     pm.run(mod)
     return mod


@@ -8,7 +8,7 @@ import os
 import subprocess
 import tempfile
 import textwrap
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
@@ -256,41 +256,126 @@ class JITFunction:
return f"JITFunction({self.module}:{self.fn.__name__})" return f"JITFunction({self.module}:{self.fn.__name__})"
def pow2_divisor(N):
if N % 16 == 0:
return 16
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
class _KernelCache:
def __init__(self,
fn: JITFunction,
fn_type: str,
constants: Dict[str, Any],
num_warps: int = 4,
num_stages: int = 3):
# hold the arguments for building a kernel
self.fn = fn
self.fn_type = fn_type
self.constants = constants
self.num_warps = num_warps
self.num_stages = num_stages
# kernel compilation cache
self._binary_cache: Optional[LoadedBinary] = None
@property
def binary_cache(self):
return self._binary_cache
def set_binary_cache(self, binary: LoadedBinary):
assert binary
assert not self._binary_cache, "cannot set binary cache duplicately"
self._binary_cache = binary
def build_kernel(fn: JITFunction, def build_kernel(fn: JITFunction,
fn_type: str, fn_type: str,
device: int,
constants: Dict[str, Any], constants: Dict[str, Any],
num_warps: int = 4, num_warps: int = 4,
num_stages: int = 3, num_stages: int = 3,
) -> LoadedBinary: ) -> _KernelCache:
cubin, shem_size, kernel_name = compile(fn, fn_type, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, output="cubin") return _KernelCache(fn, fn_type, constants, num_warps, num_stages)
torch_dtype_to_bytes = {
torch.int8: 1,
torch.uint8: 1,
torch.int16: 2,
torch.short: 2,
torch.int: 4,
torch.int32: 4,
torch.long: 8,
torch.int64: 8,
torch.float32: 4,
torch.float: 4,
torch.float16: 2,
torch.half: 2,
torch.bfloat16: 2,
# free to extend
}
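The specialization attribute computed below for a tensor argument is the largest power of two (capped at 16) dividing both its base address and its addressable range, expressed in elements of its dtype. A self-contained sketch of that computation (the address and sizes are made-up numbers; pow2_divisor mirrors the helper above):

    def pow2_divisor(n):
        for d in (16, 8, 4, 2):
            if n % d == 0:
                return d
        return 1

    def tensor_divisibility(addr, range_size, dtype_bytes):
        # mirrors: min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
        return min(pow2_divisor(addr), pow2_divisor(range_size)) // dtype_bytes

    print(tensor_divisibility(0x7f0000000000, 1024, 4))  # float32 tensor -> 4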
def launch_kernel(kernel: _KernelCache, grid, device, *wargs, **kwargs):
def is_tensor(arg):
return hasattr(arg, 'data_ptr') # a torch.tensor
# prepare function args for compile
kwargs = {kernel.fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
assert len(wargs) == len(kernel.fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(kernel.fn.arg_names), len(wargs))
if not kernel.binary_cache:
# build the kernel cache
backend = _triton.runtime.backend.CUDA
attributes = dict()
for i, arg in enumerate(wargs):
if i in kernel.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = pow2_divisor(arg)
elif is_tensor(arg):
assert arg.dtype in torch_dtype_to_bytes
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
divisibility = min(pow2_divisor(addr), pow2_divisor(range_size)) // torch_dtype_to_bytes[arg.dtype]
attributes[i] = divisibility
attributes_ = dict()
for i, value in attributes.items():
attributes_[kernel.fn.arg_names[i]] = value
cubin, shem_size, kernel_name = compile(kernel.fn, kernel.fn_type, device=device, constants=kernel.constants, attributes=attributes_, num_warps=kernel.num_warps, num_stages=kernel.num_stages, output="cubin")
assert cubin assert cubin
assert kernel_name assert kernel_name
backend = _triton.runtime.backend.CUDA
max_shared_memory = _triton.runtime.max_shared_memory(backend, device) max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size) assert shem_size <= max_shared_memory, "shared memory out of resource, max size is %d, but want %s" % (max_shared_memory, shem_size)
asm = dict(cubin=cubin) asm = dict(cubin=cubin)
binary = Binary(backend, kernel_name, asm, shem_size, num_warps) binary = Binary(backend, kernel_name, asm, shem_size, kernel.num_warps)
loaded_binary = LoadedBinary(device, binary) loaded_binary = LoadedBinary(device, binary)
return loaded_binary kernel.set_binary_cache(loaded_binary)
def launch_kernel(fn: JITFunction, binary: LoadedBinary, grid, num_warps, num_stages, *wargs, **kwargs):
kwargs = {fn.arg_names.index(name): value for name, value in kwargs.items()}
wargs = list(wargs)
for i, pos in enumerate(sorted(kwargs)):
wargs.insert(pos + i, kwargs[pos])
assert len(wargs) == len(fn.arg_names), "Function argument list not match, need %d but get %d args" % (len(fn.arg_names), len(wargs))
device = torch.cuda.current_device()
torch.cuda.set_device(device) torch.cuda.set_device(device)
stream = get_cuda_stream(device) stream = get_cuda_stream(device)
_triton.runtime.launch_binary(binary, wargs, fn.do_not_specialize, fn.arg_names, _triton.runtime.launch_binary(kernel.binary_cache, wargs, kernel.fn.do_not_specialize, kernel.fn.arg_names,
stream, num_warps, num_stages, grid) stream, kernel.num_warps, kernel.num_stages, grid)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------


@@ -28,14 +28,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // -----
-#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: vectorized_load
   func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: ld.global.v4.b32
+    // CHECK-SAME: ld.global.b32
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: ld.global.v4.b32
+    // CHECK-SAME: ld.global.b32
     %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
     return
   }
@@ -43,14 +43,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // -----
-#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
-module attributes {"triton_gpu.num-warps" = 4 : i32} {
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: vectorized_load_f16
-  func @vectorized_load_f16(%a_ptr_init : tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
+  func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: ld.global.v2.b32
+    // CHECK-SAME: ld.global.b16
     // CHECK: llvm.inline_asm
-    // CHECK-SAME: ld.global.v2.b32
+    // CHECK-SAME: ld.global.b16
     %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
     return
   }
@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // -----
 // TODO: Pending on the support of isSplat constant
-#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: masked_load_const_other
   func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
@@ -69,6 +69,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 }
 }
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256
func @kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK: ld.global.v4.b32
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// CHECK: ld.global.v4.b32
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
// Store 4 elements to global
// CHECK: st.global.b32.v4
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
}
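The v4 checks in this new case follow from the vectorization rule noted at the top: the make_range-based pointers are fully contiguous, so the alignment is limited by tt.divisibility = 4, and the layout leaves each thread with 4 elements. A back-of-the-envelope check:

    elems = 256
    threads = 32 * 2                         # threadsPerWarp * warpsPerCTA
    elems_per_thread = elems // threads      # 4
    alignment = 4                            # tt.divisibility = 4 on the fp32 pointers
    print(min(alignment, elems_per_thread))  # -> 4, hence the v4.b32 load and v4 store checks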
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp // TODO: Add a testcase to verify the optimization when ptr of the LoadOp
// is from a GEP with const idx // is from a GEP with const idx
@@ -99,7 +133,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // -----
-#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_make_range
   func @basic_make_range() {