Merge remote-tracking branch 'origin/master' into phil/fused-attention-perf-fixup

Phil Tillet
2022-12-30 11:53:49 -08:00
21 changed files with 765 additions and 337 deletions

View File

@@ -39,6 +39,8 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);
+bool isaDistributedLayout(const Attribute &layout);
} // namespace gpu
} // namespace triton
} // namespace mlir

View File

@@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
@@ -24,111 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
         dotOperandLayout.getParent() == mmaLayout;
}
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemTy, Location loc,
-                          ConversionPatternRewriter &rewriter) {
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> dstStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) {
  auto srcTy = src.getType().cast<RankedTensorType>();
  auto srcShape = srcTy.getShape();
-  assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
+  assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared");
  auto dstTy = dst.getType().cast<RankedTensorType>();
-  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcDistributedLayout = srcTy.getEncoding();
+  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
+    assert((!mmaLayout.isVolta()) &&
+           "ConvertLayout MMAv1->Shared is not suppported yet");
+  }
  auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-  auto inOrd = srcBlockedLayout.getOrder();
+  auto inOrd = getOrder(srcDistributedLayout);
  auto outOrd = dstSharedLayout.getOrder();
-  if (inOrd != outOrd)
-    llvm_unreachable(
-        "blocked -> shared with different order not yet implemented");
  unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+      inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
  unsigned outVec = dstSharedLayout.getVec();
  unsigned minVec = std::min(outVec, inVec);
  unsigned perPhase = dstSharedLayout.getPerPhase();
  unsigned maxPhase = dstSharedLayout.getMaxPhase();
  unsigned numElems = getElemsPerThread(srcTy);
+  assert(numElems == srcIndices.size());
  auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
  auto wordTy = vec_ty(elemTy, minVec);
  auto elemPtrTy = ptr_ty(elemTy);
-  // TODO: [goostavz] We should make a cache for the calculation of
-  // emitBaseIndexForBlockedLayout in case backend compiler not being able to
-  // optimize that
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
  Value outVecVal = i32_val(outVec);
  Value minVecVal = i32_val(minVec);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
+  Value word;
  for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of
-        // input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
-        multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
-        phaseId = urem(phaseId, i32_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
-    }
+    if (i % minVec == 0)
+      word = undef(wordTy);
+    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
+    if (i % minVec == minVec - 1) {
+      // step 1: recover the multidim_index from the index of
+      SmallVector<Value> multiDimIdx = srcIndices[i];
+      SmallVector<Value> dbgVal = srcIndices[i];
+      // step 2: do swizzling
+      Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+      multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+      Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
+      Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+      phaseId = urem(phaseId, i32_val(maxPhase));
+      Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+      off_0 = mul(off_0, outVecVal);
+      remained = udiv(remained, minVecVal);
+      off_0 = add(off_0, mul(remained, minVecVal));
+      Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
+      // step 3: store
+      Value smemAddr = gep(elemPtrTy, smemBase, offset);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+      store(word, smemAddr);
+    }
  }
}
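The swizzling in steps 2 and 3 is easier to follow outside the IR-builder macros. In the new loop, every minVec consecutive elements are packed into one word, and each full word is stored at a swizzled offset. Below is a minimal standalone sketch (not part of the commit) of the same arithmetic on plain integers; the function and parameter names are made up for illustration, with row/col standing for multiDimIdx[outOrd[1]]/multiDimIdx[outOrd[0]] and the strides standing for dstStrides along those two dimensions.

    #include <cstdint>

    // Scalar model of the swizzled shared-memory offset computed above.
    // col is the contiguous dimension (outOrd[0]); row is the other one.
    int64_t swizzledSharedOffset(int64_t row, int64_t col, int64_t rowStride,
                                 int64_t colStride, unsigned perPhase,
                                 unsigned maxPhase, unsigned outVec,
                                 unsigned minVec) {
      int64_t remained = col % outVec;            // position inside an outVec group
      int64_t vecIdx = col / outVec;              // which outVec group in the row
      int64_t off1 = row * rowStride;             // offset contributed by the row
      int64_t phaseId = (row / perPhase) % maxPhase;
      int64_t off0 = (vecIdx ^ phaseId) * outVec; // XOR-swizzle the group index
      off0 += (remained / minVec) * minVec;       // keep minVec-wide words intact
      return off1 + off0 * colStride;             // element offset into smemBase
    }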
@@ -148,20 +102,15 @@ public:
    auto dstTy = dst.getType().cast<RankedTensorType>();
    Attribute srcLayout = srcTy.getEncoding();
    Attribute dstLayout = dstTy.getEncoding();
-    if (srcLayout.isa<BlockedEncodingAttr>() &&
+    if (isaDistributedLayout(srcLayout) &&
        dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerBlockedToShared(op, adaptor, rewriter);
+      return lowerDistributedToShared(op, adaptor, rewriter);
    }
    if (srcLayout.isa<SharedEncodingAttr>() &&
        dstLayout.isa<DotOperandEncodingAttr>()) {
      return lowerSharedToDotOperand(op, adaptor, rewriter);
    }
-    if ((srcLayout.isa<BlockedEncodingAttr>() ||
-         srcLayout.isa<MmaEncodingAttr>() ||
-         srcLayout.isa<SliceEncodingAttr>()) &&
-        (dstLayout.isa<BlockedEncodingAttr>() ||
-         dstLayout.isa<MmaEncodingAttr>() ||
-         dstLayout.isa<SliceEncodingAttr>())) {
+    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
      return lowerDistributedToDistributed(op, adaptor, rewriter);
    }
    if (srcLayout.isa<MmaEncodingAttr>() &&
@@ -182,7 +131,7 @@ private:
    unsigned rank = shape.size();
    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
      auto multiDimOffsetFirstElem =
-          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+          emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
      SmallVector<Value> multiDimOffset(rank);
      SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
          elemId, getSizePerThread(layout), getOrder(layout));
@@ -479,7 +428,7 @@ private:
  // Swizzling in shared memory to avoid bank conflict. Normally used for
  // A/B operands of dots.
  LogicalResult
-  lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
+  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                           ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    Value src = op.src();
@@ -490,22 +439,20 @@ private:
    auto dstShape = dstTy.getShape();
    assert(srcShape.size() == 2 &&
           "Unexpected rank of ConvertLayout(blocked->shared)");
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
    auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
+    auto inOrd = getOrder(srcLayout);
    auto outOrd = dstSharedLayout.getOrder();
    Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
    auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
    auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
    smemBase = bitcast(smemBase, elemPtrTy);
-    auto srcStrides =
-        getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
-                                                    srcBlockedLayout, srcShape);
-    storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
-                         smemBase, elemTy, loc, rewriter);
+    auto dstStrides =
+        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
+                             smemBase, elemTy, loc, rewriter);
    auto smemObj =
        SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -680,7 +627,9 @@ private:
void populateConvertLayoutOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit) {
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
-                                          benefit);
+                                          indexCacheInfo, benefit);
}

View File

@@ -11,14 +11,18 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                        DotOperandEncodingAttr &dotOperandLayout);
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemPtrTy, Location loc,
-                          ConversionPatternRewriter &rewriter);
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> srcStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemPtrTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter);
void populateConvertLayoutOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit);
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
#endif

View File

@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
      offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
    }
-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    if (tensorTy.getElementType().isBF16()) {
+      elemX2Ty = vec_ty(i16_ty, 2);
+      elemPtrTy = ptr_ty(i16_ty);
+    }
    // prepare arguments
    SmallVector<Value> ptrA(numPtrA);
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
    for (int i = 0; i < numPtrA; i++)
      ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
-    Type f16PtrTy = ptr_ty(f16_ty);
    auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
      vals[{m, k}] = {val0, val1};
    };
    auto loadA = [&](int m, int k) {
      int offidx = (isARow ? k / 4 : m) % numPtrA;
-      Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
+      Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
      int stepAM = isARow ? m : m / numPtrA * numPtrA;
      int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
      Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
                         mul(i32_val(stepAK), strideAK));
-      Value pa = gep(f16PtrTy, thePtrA, offset);
+      Value pa = gep(elemPtrTy, thePtrA, offset);
      Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
      Value ha = load(bitcast(pa, aPtrTy));
      // record lds that needs to be moved
-      Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
-      Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
+      Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
+      Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
      ld(has, m, k, ha00, ha01);
      if (vecA > 4) {
-        Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
-        Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
+        Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
+        Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
        if (isARow)
          ld(has, m, k + 4, ha10, ha11);
        else
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
      elems.push_back(item.second.second);
    }
-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
    Value res = getStructFromElements(loc, elems, rewriter, resTy);
    return res;
  }
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
      offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
    }
-    Type f16PtrTy = ptr_ty(f16_ty);
-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    if (tensorTy.getElementType().isBF16()) {
+      elemPtrTy = ptr_ty(i16_ty);
+      elemX2Ty = vec_ty(i16_ty, 2);
+    }
    SmallVector<Value> ptrB(numPtrB);
    ValueTable hbs;
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
      int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
      Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
                         mul(i32_val(stepBK), strideBK));
-      Value pb = gep(f16PtrTy, thePtrB, offset);
+      Value pb = gep(elemPtrTy, thePtrB, offset);
      Value hb =
          load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
      // record lds that needs to be moved
-      Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
-      Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
+      Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
+      Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
      ld(hbs, n, K, hb00, hb01);
      if (vecB > 4) {
-        Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
-        Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
+        Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
+        Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
        if (isBRow)
          ld(hbs, n + 1, K, hb10, hb11);
        else
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
      elems.push_back(item.second.first);
      elems.push_back(item.second.second);
    }
-    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
    Value res = getStructFromElements(loc, elems, rewriter, resTy);
    return res;
  }
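The changes above only take effect when tensorTy.getElementType().isBF16(): the loaded fragments are then carried as <2 x i16> instead of <2 x f16>, i.e. the bf16 payload is shuttled around as raw 16-bit bit patterns rather than as a floating-point vector type. A minimal standalone sketch of that reinterpretation idea is below; the struct and helpers are made-up names for illustration only, not Triton or LLVM APIs.

    #include <cstdint>

    // Two 16-bit lanes carried as raw bits of one 32-bit word, the way the
    // loads above bitcast a packed bf16 pair without doing bf16 arithmetic.
    struct RawHalfPair {
      uint16_t lo; // bit pattern of element 0
      uint16_t hi; // bit pattern of element 1
    };

    inline RawHalfPair unpackPair(uint32_t word) {
      return {static_cast<uint16_t>(word & 0xffffu),
              static_cast<uint16_t>(word >> 16)};
    }

    inline uint32_t packPair(RawHalfPair p) {
      return static_cast<uint32_t>(p.lo) | (static_cast<uint32_t>(p.hi) << 16);
    }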

View File

@@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
-struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
+struct LoadStoreConversionBase {
  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
      : axisAnalysisPass(axisAnalysisPass) {}
@@ -639,9 +639,8 @@ struct InsertSliceOpConversion
    auto smemBase = gep(elemPtrTy, smemObj.base, offset);
    auto llSrc = adaptor.source();
-    auto srcIndices =
-        emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
-    storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
-                         elemTy, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
+                             elemTy, loc, rewriter);
    // Barrier is not necessary.
    // The membar pass knows that it writes to shared memory and will handle it
@@ -657,12 +656,12 @@ struct InsertSliceAsyncOpConversion
  using ConvertTritonGPUOpToLLVMPattern<
      triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
-  InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
-                               const Allocation *allocation, Value smem,
-                               AxisInfoAnalysis &axisAnalysisPass,
-                               PatternBenefit benefit)
+  InsertSliceAsyncOpConversion(
+      LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, smem, indexCacheInfo, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}
  LogicalResult
@@ -865,11 +864,11 @@ struct InsertSliceAsyncOpConversion
  }
};
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
-                                       PatternBenefit benefit) {
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
  patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
@@ -878,7 +877,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                      axisInfoAnalysis, benefit);
  patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
-                                        benefit);
+                                        indexCacheInfo, benefit);
  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
-                                             axisInfoAnalysis, benefit);
+                                             indexCacheInfo, axisInfoAnalysis,
+                                             benefit);
}

View File

@@ -6,11 +6,11 @@
using namespace mlir;
using namespace mlir::triton;
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
-                                       PatternBenefit benefit);
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
#endif

View File

@@ -164,7 +164,7 @@ private:
    auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForBlockedLayout(srcLayout, srcShape);
+        emitOffsetForLayout(srcLayout, srcShape);
    std::map<SmallVector<unsigned>, Value> accs;
    std::map<SmallVector<unsigned>, Value> accIndices;
@@ -479,10 +479,12 @@ private:
  }
};
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit) {
-  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
+  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
+                                   indexCacheInfo, benefit);
}

View File

@@ -6,10 +6,11 @@
using namespace mlir;
using namespace mlir::triton;
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
#endif

View File

@@ -63,6 +63,7 @@ struct BroadcastOpConversion
    auto srcShape = srcTy.getShape();
    auto resultShape = resultTy.getShape();
    unsigned rank = srcTy.getRank();
+    assert(rank == resultTy.getRank());
    auto order = triton::gpu::getOrder(srcLayout);
    auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
@@ -272,8 +273,12 @@ struct PrintfOpConversion
struct MakeRangeOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
-  MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
-                                                             benefit) {}
+  MakeRangeOpConversion(
+      LLVMTypeConverter &converter,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
+            converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
+            benefit) {}
  LogicalResult
@@ -500,10 +505,11 @@ struct AsyncWaitOpConversion
  }
};
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit) {
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
  patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
@@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                        benefit);
  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
  patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
-  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
+  patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
  patterns.add<ReturnOpConversion>(typeConverter, benefit);
  patterns.add<PrintfOpConversion>(typeConverter, benefit);
}

View File

@@ -6,10 +6,11 @@
using namespace mlir;
using namespace mlir::triton;
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit);
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
#endif

View File

@@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
@@ -128,7 +127,60 @@ protected:
  }
};
-struct ConvertTritonGPUOpToLLVMPatternBase {
+using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
struct CacheKeyDenseMapInfo {
static IndexCacheKeyT getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{});
}
static IndexCacheKeyT getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
}
static unsigned getHashValue(IndexCacheKeyT key) {
return llvm::hash_combine(
mlir::hash_value(key.first),
llvm::hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
return LHS == RHS;
}
};
class ConvertTritonGPUOpToLLVMPatternBase {
public:
// Two levels of value cache in emitting indices calculation:
// Key: pair<layout, shape>
struct IndexCacheInfo {
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
*baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo> *indexCache;
OpBuilder::InsertPoint *indexInsertPoint;
};
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
: converter(&typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem)
: converter(&typeConverter), allocation(allocation), smem(smem) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), indexCacheInfo(indexCacheInfo),
allocation(allocation), smem(smem) {}
LLVMTypeConverter *getTypeConverter() const { return converter; }
  static Value
  getStructFromSharedMemoryObject(Location loc,
                                  const SharedMemoryObject &smemObj,
@@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
    return getStructFromElements(loc, elems, rewriter, structTy);
  }
-};
-template <typename SourceOp>
-class ConvertTritonGPUOpToLLVMPattern
-    : public ConvertOpToLLVMPattern<SourceOp>,
-      public ConvertTritonGPUOpToLLVMPatternBase {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           const Allocation *allocation,
-                                           Value smem,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
-        allocation(allocation), smem(smem) {}
  Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
@@ -169,6 +202,23 @@ public:
    return threadId;
  }
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
Value offVal = idx_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
}
  // -----------------------------------------------------------------------
  // Utilities
  // -----------------------------------------------------------------------
@@ -242,6 +292,116 @@ public:
    return ret;
  }
struct SmallVectorKeyInfo {
static unsigned getHashValue(const SmallVector<unsigned> &key) {
return llvm::hash_combine_range(key.begin(), key.end());
}
static bool isEqual(const SmallVector<unsigned> &lhs,
const SmallVector<unsigned> &rhs) {
return lhs == rhs;
}
static SmallVector<unsigned> getEmptyKey() {
return SmallVector<unsigned>();
}
static SmallVector<unsigned> getTombstoneKey() {
return {std::numeric_limits<unsigned>::max()};
}
};
// -----------------------------------------------------------------------
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
auto cache = indexCacheInfo.baseIndexCache;
assert(cache && "baseIndexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
restoreInsertionPointIfSet(insertPt, rewriter);
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
return result;
}
}
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitOffsetForBlockedLayout(blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
}
llvm_unreachable("unsupported emitOffsetForLayout");
}
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key(layout, llvm::to_vector(shape));
auto cache = indexCacheInfo.indexCache;
assert(cache && "indexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(b);
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"implemented yet");
}
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
return result;
}
}
private:
void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
ConversionPatternRewriter &rewriter) const {
if (insertPt->isSet()) {
rewriter.restoreInsertionPoint(*insertPt);
} else {
auto func =
rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
rewriter.setInsertionPointToStart(&func.getBody().front());
}
}
  // -----------------------------------------------------------------------
  // Blocked layout indices
  // -----------------------------------------------------------------------
@@ -411,38 +571,6 @@ public:
    return ret;
  }
-  // -----------------------------------------------------------------------
-  // Get offsets / indices for any layout
-  // -----------------------------------------------------------------------
-  SmallVector<Value> emitBaseIndexForLayout(Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            const Attribute &layout,
-                                            ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitBaseIndexForLayout");
-  }
-  SmallVector<SmallVector<unsigned>>
-  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitOffsetForBlockedLayout(blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitOffsetForLayout");
-  }
  // Emit indices calculation within each ConversionPattern, and returns a
  // [elemsPerThread X rank] index matrix.
@@ -470,22 +598,6 @@ public:
    return multiDimIdx;
  }
-  struct SmallVectorKeyInfo {
-    static unsigned getHashValue(const SmallVector<unsigned> &key) {
-      return llvm::hash_combine_range(key.begin(), key.end());
-    }
-    static bool isEqual(const SmallVector<unsigned> &lhs,
-                        const SmallVector<unsigned> &rhs) {
-      return lhs == rhs;
-    }
-    static SmallVector<unsigned> getEmptyKey() {
-      return SmallVector<unsigned>();
-    }
-    static SmallVector<unsigned> getTombstoneKey() {
-      return {std::numeric_limits<unsigned>::max()};
-    }
-  };
  SmallVector<SmallVector<Value>>
  emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
                            const SliceEncodingAttr &sliceLayout,
@@ -505,46 +617,45 @@ public:
    return resultIndices;
  }
-  // -----------------------------------------------------------------------
-  // Emit indices
-  // -----------------------------------------------------------------------
-  SmallVector<SmallVector<Value>> emitIndices(Location loc,
-                                              ConversionPatternRewriter &b,
-                                              const Attribute &layout,
-                                              ArrayRef<int64_t> shape) const {
-    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, blocked, shape);
-    } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, mma, shape);
-    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
-      return emitIndicesForSliceLayout(loc, b, slice, shape);
-    } else {
-      assert(0 && "emitIndices for layouts other than blocked & slice not "
-                  "implemented yet");
-      return {};
-    }
-  }
-  // -----------------------------------------------------------------------
-  // Shared memory utilities
-  // -----------------------------------------------------------------------
-  template <typename T>
-  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
-                            T value) const {
-    auto ptrTy = LLVM::LLVMPointerType::get(
-        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
-    auto bufferId = allocation->getBufferId(value);
-    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
-    size_t offset = allocation->getOffset(bufferId);
-    Value offVal = idx_val(offset);
-    Value base = gep(ptrTy, smem, offVal);
-    return base;
-  }
protected:
+  LLVMTypeConverter *converter;
  const Allocation *allocation;
  Value smem;
+  IndexCacheInfo indexCacheInfo;
+};
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp>,
public ConvertTritonGPUOpToLLVMPatternBase {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
indexCacheInfo) {}
protected:
LLVMTypeConverter *getTypeConverter() const {
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
}
};
#endif
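Taken together, this header now keeps two memoization tables on the pattern base: baseIndexCache for emitBaseIndexForLayout and indexCache for emitIndices, both keyed by the pair (layout attribute, tensor shape), with indexInsertPoint recording where the cached index computation was materialized so later hits can reuse the same SSA values. CacheKeyDenseMapInfo exists because llvm::DenseMap requires reserved empty/tombstone keys plus a hash and an equality predicate for that pair type. Below is a minimal standalone sketch of the same DenseMap-key pattern with a simplified key; it assumes only LLVM's ADT headers, and ToyKeyT / ToyKeyDenseMapInfo / ToyIndexCache are made-up names, not part of the commit.

    #include "llvm/ADT/DenseMap.h"
    #include "llvm/ADT/Hashing.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cstdint>
    #include <limits>
    #include <utility>

    // Simplified stand-in for IndexCacheKeyT: an opaque id plus a shape.
    using ToyKeyT = std::pair<std::uintptr_t, llvm::SmallVector<int64_t, 4>>;

    struct ToyKeyDenseMapInfo {
      // DenseMap reserves two sentinel keys that real lookups never produce.
      static ToyKeyT getEmptyKey() {
        return {llvm::DenseMapInfo<std::uintptr_t>::getEmptyKey(),
                llvm::SmallVector<int64_t, 4>{}};
      }
      static ToyKeyT getTombstoneKey() {
        return {llvm::DenseMapInfo<std::uintptr_t>::getTombstoneKey(),
                llvm::SmallVector<int64_t, 4>{
                    std::numeric_limits<int64_t>::max()}};
      }
      static unsigned getHashValue(const ToyKeyT &key) {
        return llvm::hash_combine(
            key.first,
            llvm::hash_combine_range(key.second.begin(), key.second.end()));
      }
      static bool isEqual(const ToyKeyT &lhs, const ToyKeyT &rhs) {
        return lhs == rhs;
      }
    };

    // A cache keyed the same way the two index caches above are keyed.
    using ToyIndexCache =
        llvm::DenseMap<ToyKeyT, llvm::SmallVector<int64_t, 4>,
                       ToyKeyDenseMapInfo>;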

View File

@@ -170,16 +170,20 @@ public:
    // We set a higher benefit here to ensure triton's patterns runs before
    // arith patterns for some encoding not supported by the community
    // patterns.
+    OpBuilder::InsertPoint indexInsertPoint;
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
+        &baseIndexCache, &indexCache, &indexInsertPoint};
    RewritePatternSet patterns(context);
    // Normal conversions
    populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
                                    axisInfoAnalysis, &allocation, smem,
-                                    /*benefit=*/10);
+                                    indexCacheInfo, /*benefit=*/10);
    // ConvertLayoutOp
    populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                          axisInfoAnalysis, &allocation, smem,
-                                          /*benefit=*/10);
+                                          indexCacheInfo, /*benefit=*/10);
    // DotOp
    populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                axisInfoAnalysis, &allocation, smem,
@@ -191,11 +195,11 @@ public:
    // LoadStoreOp
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                      axisInfoAnalysis, &allocation, smem,
-                                      /*benefit=*/10);
+                                      indexCacheInfo, /*benefit=*/10);
    // ReduceOp
    populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                   axisInfoAnalysis, &allocation, smem,
-                                   /*benefit=*/10);
+                                   indexCacheInfo, /*benefit=*/10);
    // ViewOp
    populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                 axisInfoAnalysis, &allocation, smem,
@@ -215,6 +219,13 @@ public:
private:
  Value smem;
+  using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+  DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
+      baseIndexCache;
+  DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
+           CacheKeyDenseMapInfo>
+      indexCache;
  int computeCapability{};
  void initSharedMemory(size_t size,

View File

@@ -4,11 +4,13 @@
using namespace mlir;
using namespace mlir::triton;
+using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
+using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;
struct SplatOpConversion
@@ -38,6 +40,11 @@ struct SplatOpConversion
          LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
      return getStructFromElements(loc, elems, rewriter, structTy);
+    } else if (auto dotLayout =
+                   tensorTy.getEncoding()
+                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
+      return convertSplatLikeOpWithDotOperandLayout(
+          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
    } else if (auto mmaLayout =
                   tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
      return convertSplatLikeOpWithMmaLayout(
@@ -48,6 +55,38 @@ struct SplatOpConversion
    return {};
  }
static Value convertSplatLikeOpWithDotOperandLayout(
const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
Type elemType, Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto parent = layout.getParent();
int numElems{};
if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
numElems = layout.getOpIdx() == 0
? MMA16816ConversionHelper::getANumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[0])
: MMA16816ConversionHelper::getBNumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[1]);
} else if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
numElems = layout.getOpIdx() == 0
? helper.numElemsPerThreadA(shape, {0, 1})
: helper.numElemsPerThreadB(shape, {0, 1});
}
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
} else {
assert(false && "Unsupported layout found");
}
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(numElems, elemType));
return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
rewriter, structTy);
}
  static Value convertSplatLikeOpWithMmaLayout(
      const MmaEncodingAttr &layout, Type resType, Type elemType,
      Value constVal, TypeConverter *typeConverter,

View File

@@ -257,6 +257,11 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
  }
};
+bool isaDistributedLayout(const Attribute &layout) {
+  return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
+         layout.isa<SliceEncodingAttr>();
+}
} // namespace gpu
} // namespace triton
} // namespace mlir

View File

@@ -23,6 +23,10 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
+using triton::DotOp;
+using triton::gpu::ConvertLayoutOp;
+using triton::gpu::DotOperandEncodingAttr;
+using triton::gpu::MmaEncodingAttr;
// -----------------------------------------------------------------------------
//
@@ -1020,6 +1024,7 @@ public:
        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
      return failure();
    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
        op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1061,7 +1066,8 @@ public:
    auto dotOp = cast<triton::DotOp>(op);
    // TODO: Check data-types and SM compatibility
    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
-    if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
+    if (!oldRetType.getEncoding() ||
+        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return failure();
    auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1171,7 +1177,8 @@ public:
    for (size_t i = 0; i < newInitArgs.size(); i++) {
      auto initArg = newInitArgs[i];
      auto regionArg = forOp.getRegionIterArgs()[i];
-      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
+          newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
        shouldRematerialize = true;
        break;
      }
@@ -1187,15 +1194,207 @@ public:
    BlockAndValueMapping mapping;
    for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
      mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
    for (Operation &op : forOp.getBody()->getOperations()) {
-      Operation *newOp = rewriter.clone(op, mapping);
+      rewriter.clone(op, mapping);
    }
    rewriter.replaceOp(forOp, newForOp.getResults());
    return success();
  }
};
// This pattern collects the wrong Mma those need to update and create the right
// ones for each.
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
public:
CollectMmaToUpdateForVolta(
mlir::MLIRContext *ctx,
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
mmaToUpdate(mmaToUpdate) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return failure();
// Has processed.
if (mmaToUpdate.count(mmaLayout))
return failure();
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4, isBVec4] =
mmaLayout.decodeVoltaLayoutStates();
if (isARow_ == isARow && isBRow_ == isBRow) {
return failure(); // No need to update
}
auto newMmaLayout = MmaEncodingAttr::get(
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
AT.getShape(), BT.getShape(), isARow, isBRow);
    // Record the incorrect MMA layout and mark it for update.
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
return failure();
}
};
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
enum class Kind {
kUnk,
kCvtToMma,
kCvtToDotOp,
kDot,
kConstant,
};
mutable Kind rewriteKind{Kind::kUnk};
public:
UpdateMMAVersionMinorForVolta(
mlir::MLIRContext *ctx, llvm::StringRef opName,
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
LogicalResult match(Operation *op) const override {
MmaEncodingAttr mma;
if (mmaToUpdate.empty())
return failure();
if (op->getNumResults() != 1)
return failure();
auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return failure();
// ConvertLayoutOp
if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
// cvt X -> dot_operand
if (auto dotOperand =
tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kCvtToDotOp;
if (mma && mmaToUpdate.count(mma))
return success();
}
if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
// cvt X -> mma
rewriteKind = Kind::kCvtToMma;
if (mma && mmaToUpdate.count(mma))
return success();
}
} else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
// DotOp
mma = dot.d()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kDot;
} else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
// ConstantOp
mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kConstant;
}
return success(mma && mmaToUpdate.count(mma));
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
switch (rewriteKind) {
case Kind::kDot:
rewriteDot(op, rewriter);
break;
case Kind::kConstant:
rewriteConstant(op, rewriter);
break;
case Kind::kCvtToDotOp:
rewriteCvtDotOp(op, rewriter);
break;
case Kind::kCvtToMma:
rewriteCvtToMma(op, rewriter);
break;
default:
      llvm::report_fatal_error("Unsupported rewrite kind");
}
}
private:
void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr newMma =
mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
auto newDotOperand = DotOperandEncodingAttr::get(
ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
auto newTensorTy = RankedTensorType::get(
tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto dot = llvm::cast<DotOp>(op);
auto tensorTy = dot.d().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
dot.c(), dot.allowTF32());
}
void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto constant = llvm::cast<arith::ConstantOp>(op);
auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet =
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
return;
}
    assert(false && "Unsupported ConstantOp value type");
}
};
} // namespace } // namespace
#define GEN_PASS_CLASSES #define GEN_PASS_CLASSES
@@ -1230,6 +1429,28 @@ public:
signalPassFailure(); signalPassFailure();
} }
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
{
mlir::RewritePatternSet patterns(context);
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
{
mlir::RewritePatternSet patterns(context);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, DotOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, arith::ConstantOp::getOperationName(), mmaToUpdate);
mlir::GreedyRewriteConfig config;
config.useTopDownTraversal = true;
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
}
mlir::RewritePatternSet loopFixup(context); mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context); loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) { if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
@@ -141,10 +141,10 @@ class CMakeBuild(build_ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTRITON_BUILD_TUTORIALS=OFF", "-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON", "-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable, "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir, "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir "-DLLVM_EXTERNAL_LIT=" + lit_dir,
] + thirdparty_cmake_args ] + thirdparty_cmake_args
# configuration # configuration
@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
# TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
@pytest.mark.parametrize("expr, dtype_str", [ @pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d) (f'x[{s}]', d)
for s in ['None, :', ':, None'] for s in ['None, :', ':, None',
# FIXME: 3d indexing doesn't work 'None, :, :',
#'None, :, :', ':, :, None']
# ':, :, None']
for d in ['int32', 'uint32', 'uint16'] for d in ['int32', 'uint32', 'uint16']
]) ])
def test_index1d(expr, dtype_str, device='cuda'): def test_index1d(expr, dtype_str, device='cuda'):
@@ -1228,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
elif dtype == 'int8': elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# FIXME: Unsupported layout found in ConvertSplatLikeOp
# def test_dot_without_load(): def test_dot_without_load():
# @triton.jit @triton.jit
# def kernel(out): def kernel(out):
# pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
# a = tl.zeros((32, 32), tl.float32) a = tl.zeros((32, 32), tl.float32)
# b = tl.zeros((32, 32), tl.float32) b = tl.zeros((32, 32), tl.float32)
# c = tl.zeros((32, 32), tl.float32) c = tl.zeros((32, 32), tl.float32)
# c = tl.dot(a, b) c = tl.dot(a, b)
# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(pout, c) tl.store(pout, c)
#
# out = torch.ones((32, 32), dtype=torch.float32, device="cuda") out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
# kernel[(1,)](out) kernel[(1,)](out)
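Since this test now runs end to end, a hedged follow-up assertion could be added after the launch; a sketch only, assuming torch is already imported in this test module, and not part of the commit:
# a and b are zero tensors, so tl.dot(a, b) is all zeros and the store
# overwrites the ones-initialized output buffer with zeros.
assert torch.equal(out, torch.zeros((32, 32), dtype=torch.float32, device="cuda"))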
# --------------- # ---------------
# test arange # test arange
@@ -20,6 +20,8 @@ from .core import (
atomic_xor, atomic_xor,
bfloat16, bfloat16,
block_type, block_type,
broadcast,
broadcast_to,
cat, cat,
cdiv, cdiv,
constexpr, constexpr,
@@ -105,6 +107,8 @@ __all__ = [
"atomic_xor", "atomic_xor",
"bfloat16", "bfloat16",
"block_type", "block_type",
"broadcast",
"broadcast_to",
"builtin", "builtin",
"cat", "cat",
"cdiv", "cdiv",
@@ -596,11 +596,9 @@ class tensor:
if isinstance(slices, slice): if isinstance(slices, slice):
slices = [slices] slices = [slices]
ret = self ret = self
n_inserted = 0
for dim, sl in enumerate(slices): for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None: if isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim + n_inserted, _builder) ret = semantic.expand_dims(ret, dim, _builder)
n_inserted += 1
elif sl == slice(None, None, None): elif sl == slice(None, None, None):
pass pass
else: else:
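Dropping n_inserted makes the None handling line up with NumPy semantics: dim already indexes into the slices tuple, which itself accounts for axes inserted by earlier None entries, so the extra offset would only change the result (and diverge from NumPy) when an index expression contains more than one None. A NumPy sketch of the intended behavior; illustrative only, not part of the commit:
import numpy as np

x = np.zeros((3, 4))
ret = x
for dim, sl in enumerate((None, slice(None), None)):  # mimics x[None, :, None]
    if sl is None:
        ret = np.expand_dims(ret, dim)  # use dim directly, no extra offset
print(ret.shape)  # (1, 3, 1, 4), matching x[None, :, None].shape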
@@ -997,8 +997,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// ----- // -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.nctaid.x // CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y // CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z // CHECK: nvvm.read.ptx.sreg.nctaid.z
@@ -1011,6 +1010,48 @@ func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
tt.store %a, %0 : tensor<32xi32, #blocked0> tt.store %a, %0 : tensor<32xi32, #blocked0>
return return
}
} }
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: test_index_cache
func @test_index_cache() {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_base_index_cache
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_index_cache_different_block
func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
scf.if %arg1 {
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
}
return
}
} }
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s // RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
func @cst() -> tensor<1024xi32, #layout1> { func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0> %cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK-LABEL: transpose // CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout // CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]> // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
// CHECK: return // CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1> %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
tt.store %21, %22 : tensor<256xf32, #layout1> tt.store %21, %22 : tensor<256xf32, #layout1>
return return
} }
// -----
// check the UpdateMMAVersionMinorForVolta pattern
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
// Here, the isMMAv1Row flags on $a's and $b's dot_operand layouts do not match
// what #mma0's versionMinor encodes, so the pattern should update versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// The pattern creates a new MMA layout that matches $a's and $b's dot_operand layouts.
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_mmav1
func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
%res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
return %res : tensor<16x16xf32, #blocked0>
}
}
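For orientation, the MMAv1 path checked by this lit test is what an ordinary tl.dot kernel reaches when compiled for a Volta (sm_70) GPU; a hedged Python-level equivalent, hardware-dependent and not part of the commit:
import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs_m = tl.arange(0, BLOCK)[:, None]
    offs_n = tl.arange(0, BLOCK)[None, :]
    a = tl.load(a_ptr + offs_m * BLOCK + offs_n)
    b = tl.load(b_ptr + offs_m * BLOCK + offs_n)
    c = tl.dot(a, b)  # on sm_70 this lowers through the versionMajor = 1 MMA layout
    tl.store(c_ptr + offs_m * BLOCK + offs_n, c)

a = torch.randn((16, 16), dtype=torch.float16, device="cuda")
b = torch.randn((16, 16), dtype=torch.float16, device="cuda")
c = torch.empty((16, 16), dtype=torch.float32, device="cuda")
dot_kernel[(1,)](a, b, c, BLOCK=16)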