Merge remote-tracking branch 'origin/master' into phil/fused-attention-perf-fixup
@@ -39,6 +39,8 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
 
 SmallVector<unsigned> getOrder(const Attribute &layout);
 
+bool isaDistributedLayout(const Attribute &layout);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
@@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getStructFromElements;
 using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::getContigPerThread;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
+using ::mlir::triton::gpu::isaDistributedLayout;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
@@ -24,111 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
          dotOperandLayout.getParent() == mmaLayout;
 }
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemTy, Location loc,
-                          ConversionPatternRewriter &rewriter) {
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> dstStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter) {
   auto srcTy = src.getType().cast<RankedTensorType>();
   auto srcShape = srcTy.getShape();
-  assert(srcShape.size() == 2 && "Unexpected rank of insertSlice");
+  assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared");
 
   auto dstTy = dst.getType().cast<RankedTensorType>();
-  auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto srcDistributedLayout = srcTy.getEncoding();
+  if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
+    assert((!mmaLayout.isVolta()) &&
+           "ConvertLayout MMAv1->Shared is not suppported yet");
+  }
   auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-  auto inOrd = srcBlockedLayout.getOrder();
+  auto inOrd = getOrder(srcDistributedLayout);
   auto outOrd = dstSharedLayout.getOrder();
-  if (inOrd != outOrd)
-    llvm_unreachable(
-        "blocked -> shared with different order not yet implemented");
   unsigned inVec =
-      inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
+      inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
   unsigned outVec = dstSharedLayout.getVec();
   unsigned minVec = std::min(outVec, inVec);
   unsigned perPhase = dstSharedLayout.getPerPhase();
   unsigned maxPhase = dstSharedLayout.getMaxPhase();
   unsigned numElems = getElemsPerThread(srcTy);
+  assert(numElems == srcIndices.size());
   auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
-  auto srcAccumSizeInThreads =
-      product<unsigned>(srcBlockedLayout.getSizePerThread());
   auto wordTy = vec_ty(elemTy, minVec);
   auto elemPtrTy = ptr_ty(elemTy);
 
-  // TODO: [goostavz] We should make a cache for the calculation of
-  // emitBaseIndexForBlockedLayout in case backend compiler not being able to
-  // optimize that
-  SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
-  SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
-                             ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
-
-  // Visit each input value in the order they are placed in inVals
-  //
-  // Please note that the order was not awaring of blockLayout.getOrder(),
-  // thus the adjacent elems may not belong to a same word. This could be
-  // improved if we update the elements order by emitIndicesForBlockedLayout()
-  SmallVector<unsigned> wordsInEachRep(2);
-  wordsInEachRep[0] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[0] / minVec
-                          : srcBlockedLayout.getSizePerThread()[0];
-  wordsInEachRep[1] = inOrd[0] == 0
-                          ? srcBlockedLayout.getSizePerThread()[1]
-                          : srcBlockedLayout.getSizePerThread()[1] / minVec;
   Value outVecVal = i32_val(outVec);
   Value minVecVal = i32_val(minVec);
-  auto numWordsEachRep = product<unsigned>(wordsInEachRep);
-  SmallVector<Value> wordVecs(numWordsEachRep);
+  Value word;
   for (unsigned i = 0; i < numElems; ++i) {
-    if (i % srcAccumSizeInThreads == 0) {
-      // start of a replication
-      for (unsigned w = 0; w < numWordsEachRep; ++w) {
-        wordVecs[w] = undef(wordTy);
-      }
-    }
-    unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
-    auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
-        linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd);
-    unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
-    auto wordVecIdx =
-        getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep, inOrd);
-    wordVecs[wordVecIdx] =
-        insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos));
-
-    if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
-      // end of replication, store the vectors into shared memory
-      unsigned linearRepIdx = i / srcAccumSizeInThreads;
-      auto multiDimRepIdx =
-          getMultiDimIndex<unsigned>(linearRepIdx, reps, inOrd);
-      for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
-           ++linearWordIdx) {
-        // step 1: recover the multidim_index from the index of
-        // input_elements
-        auto multiDimWordIdx =
-            getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep, inOrd);
-        SmallVector<Value> multiDimIdx(2);
-        auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
-                           multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
-        auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
-                           multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
-        multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0));
-        multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1));
-
-        // step 2: do swizzling
-        Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
-        multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
-        Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]);
-        Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
-        phaseId = urem(phaseId, i32_val(maxPhase));
-        Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
-        off_0 = mul(off_0, outVecVal);
-        remained = udiv(remained, minVecVal);
-        off_0 = add(off_0, mul(remained, minVecVal));
-        Value offset = add(off_1, off_0);
-
-        // step 3: store
-        Value smemAddr = gep(elemPtrTy, smemBase, offset);
-        smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
-        store(wordVecs[linearWordIdx], smemAddr);
-      }
+    if (i % minVec == 0)
+      word = undef(wordTy);
+    word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
+    if (i % minVec == minVec - 1) {
+      // step 1: recover the multidim_index from the index of
+      SmallVector<Value> multiDimIdx = srcIndices[i];
+      SmallVector<Value> dbgVal = srcIndices[i];
+
+      // step 2: do swizzling
+      Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
+      multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
+      Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
+      Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
+      phaseId = urem(phaseId, i32_val(maxPhase));
+      Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
+      off_0 = mul(off_0, outVecVal);
+      remained = udiv(remained, minVecVal);
+      off_0 = add(off_0, mul(remained, minVecVal));
+      Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
+
+      // step 3: store
+      Value smemAddr = gep(elemPtrTy, smemBase, offset);
+      smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
+      store(word, smemAddr);
    }
  }
 }
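
Note: the swizzled shared-memory offset assembled above out of MLIR Value
arithmetic reduces to a short scalar formula. A minimal standalone sketch of
that arithmetic follows (hypothetical inputs; plain C++ integers stand in for
the urem/udiv/xor_ IR builders). It is illustrative only and not part of the
change.

    // Scalar model of the offset computed in storeDistributedToShared.
    // idx0 is the index along the contiguous dimension outOrd[0], idx1 the
    // index along outOrd[1]; stride0/stride1 are the matching dst strides.
    #include <cstdio>

    unsigned swizzledOffset(unsigned idx0, unsigned idx1, unsigned stride0,
                            unsigned stride1, unsigned outVec, unsigned minVec,
                            unsigned perPhase, unsigned maxPhase) {
      unsigned remained = idx0 % outVec;          // position inside one vector
      idx0 /= outVec;                             // vector index
      unsigned off1 = idx1 * stride1;
      unsigned phaseId = (idx1 / perPhase) % maxPhase;
      unsigned off0 = (idx0 ^ phaseId) * outVec   // swizzle across banks
                      + (remained / minVec) * minVec;
      return off1 + off0 * stride0;               // final element offset
    }

    int main() {
      // hypothetical tile: outVec = minVec = 8, perPhase = 1, maxPhase = 8
      std::printf("%u\n", swizzledOffset(16, 3, 1, 64, 8, 8, 1, 8));
      return 0;
    }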
@@ -148,20 +102,15 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (srcLayout.isa<BlockedEncodingAttr>() &&
+    if (isaDistributedLayout(srcLayout) &&
         dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerBlockedToShared(op, adaptor, rewriter);
+      return lowerDistributedToShared(op, adaptor, rewriter);
     }
     if (srcLayout.isa<SharedEncodingAttr>() &&
         dstLayout.isa<DotOperandEncodingAttr>()) {
       return lowerSharedToDotOperand(op, adaptor, rewriter);
     }
-    if ((srcLayout.isa<BlockedEncodingAttr>() ||
-         srcLayout.isa<MmaEncodingAttr>() ||
-         srcLayout.isa<SliceEncodingAttr>()) &&
-        (dstLayout.isa<BlockedEncodingAttr>() ||
-         dstLayout.isa<MmaEncodingAttr>() ||
-         dstLayout.isa<SliceEncodingAttr>())) {
+    if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
     if (srcLayout.isa<MmaEncodingAttr>() &&
@@ -182,7 +131,7 @@ private:
     unsigned rank = shape.size();
     if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
       auto multiDimOffsetFirstElem =
-          emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+          emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
       SmallVector<Value> multiDimOffset(rank);
       SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
           elemId, getSizePerThread(layout), getOrder(layout));
@@ -479,8 +428,8 @@ private:
   // Swizzling in shared memory to avoid bank conflict. Normally used for
   // A/B operands of dots.
   LogicalResult
-  lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     Value src = op.src();
     Value dst = op.result();
@@ -490,22 +439,20 @@ private:
     auto dstShape = dstTy.getShape();
     assert(srcShape.size() == 2 &&
            "Unexpected rank of ConvertLayout(blocked->shared)");
-    auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    auto srcLayout = srcTy.getEncoding();
     auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
+    auto inOrd = getOrder(srcLayout);
     auto outOrd = dstSharedLayout.getOrder();
     Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
     auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
     smemBase = bitcast(smemBase, elemPtrTy);
 
-    auto srcStrides =
-        getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
-    auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter,
-                                                    srcBlockedLayout, srcShape);
-    storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst,
-                         smemBase, elemTy, loc, rewriter);
+    auto dstStrides =
+        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
+                             smemBase, elemTy, loc, rewriter);
 
     auto smemObj =
         SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -680,7 +627,9 @@ private:
 void populateConvertLayoutOpToLLVMPatterns(
     mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit) {
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
-                                          benefit);
+                                          indexCacheInfo, benefit);
 }
@@ -11,14 +11,18 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
                         DotOperandEncodingAttr &dotOperandLayout);
 
-void storeBlockedToShared(Value src, Value llSrc, ArrayRef<Value> srcStrides,
-                          ArrayRef<Value> srcIndices, Value dst, Value smemBase,
-                          Type elemPtrTy, Location loc,
-                          ConversionPatternRewriter &rewriter);
+void storeDistributedToShared(Value src, Value llSrc,
+                              ArrayRef<Value> srcStrides,
+                              ArrayRef<SmallVector<Value>> srcIndices,
+                              Value dst, Value smemBase, Type elemPtrTy,
+                              Location loc,
+                              ConversionPatternRewriter &rewriter);
 
 void populateConvertLayoutOpToLLVMPatterns(
     mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
-    const Allocation *allocation, Value smem, PatternBenefit benefit);
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 
 #endif
@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
       offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
     }
 
-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    if (tensorTy.getElementType().isBF16()) {
+      elemX2Ty = vec_ty(i16_ty, 2);
+      elemPtrTy = ptr_ty(i16_ty);
+    }
 
     // prepare arguments
     SmallVector<Value> ptrA(numPtrA);
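
Note: the f16x2Ty -> elemX2Ty / elemPtrTy change above works because these
loads only shuffle bits; a bf16 value can be carried in an i16 lane without
loss and reinterpreted later. A small standalone sketch of that
reinterpretation (hypothetical helper names, not part of the change):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // bf16 is the top 16 bits of an IEEE-754 float, so moving it as uint16_t
    // (the i16 vectors used above) preserves the value exactly.
    std::uint16_t bf16BitsFromFloat(float f) {
      std::uint32_t u;
      std::memcpy(&u, &f, sizeof(u));
      return static_cast<std::uint16_t>(u >> 16);
    }

    float floatFromBf16Bits(std::uint16_t b) {
      std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    int main() {
      std::uint16_t bits = bf16BitsFromFloat(1.5f); // 1.5 is exact in bf16
      std::printf("0x%04x -> %g\n", bits, floatFromBf16Bits(bits));
      return 0;
    }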
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
     for (int i = 0; i < numPtrA; i++)
       ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
 
-    Type f16PtrTy = ptr_ty(f16_ty);
-
     auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
       vals[{m, k}] = {val0, val1};
     };
     auto loadA = [&](int m, int k) {
       int offidx = (isARow ? k / 4 : m) % numPtrA;
-      Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
+      Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
 
       int stepAM = isARow ? m : m / numPtrA * numPtrA;
       int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
       Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
                          mul(i32_val(stepAK), strideAK));
-      Value pa = gep(f16PtrTy, thePtrA, offset);
+      Value pa = gep(elemPtrTy, thePtrA, offset);
       Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
       Value ha = load(bitcast(pa, aPtrTy));
       // record lds that needs to be moved
-      Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
-      Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
+      Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
+      Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
       ld(has, m, k, ha00, ha01);
 
       if (vecA > 4) {
-        Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
-        Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
+        Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
+        Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
         if (isARow)
           ld(has, m, k + 4, ha10, ha11);
         else
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
       elems.push_back(item.second.second);
     }
 
-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
     Value res = getStructFromElements(loc, elems, rewriter, resTy);
     return res;
   }
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
       offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
     }
 
-    Type f16PtrTy = ptr_ty(f16_ty);
-    Type f16x2Ty = vec_ty(f16_ty, 2);
+    Type elemPtrTy = ptr_ty(f16_ty);
+    Type elemX2Ty = vec_ty(f16_ty, 2);
+    if (tensorTy.getElementType().isBF16()) {
+      elemPtrTy = ptr_ty(i16_ty);
+      elemX2Ty = vec_ty(i16_ty, 2);
+    }
 
     SmallVector<Value> ptrB(numPtrB);
     ValueTable hbs;
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
       int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
       Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
                          mul(i32_val(stepBK), strideBK));
-      Value pb = gep(f16PtrTy, thePtrB, offset);
+      Value pb = gep(elemPtrTy, thePtrB, offset);
 
       Value hb =
           load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
       // record lds that needs to be moved
-      Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
-      Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
+      Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
+      Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
       ld(hbs, n, K, hb00, hb01);
       if (vecB > 4) {
-        Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
-        Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
+        Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
+        Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
         if (isBRow)
           ld(hbs, n + 1, K, hb10, hb11);
         else
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
       elems.push_back(item.second.first);
       elems.push_back(item.second.second);
     }
-    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
-    Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
+    Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
     Value res = getStructFromElements(loc, elems, rewriter, resTy);
     return res;
   }
@@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
 // Contains some helper functions for both Load and Store conversions.
-struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
+struct LoadStoreConversionBase {
   explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
       : axisAnalysisPass(axisAnalysisPass) {}
 
@@ -639,10 +639,9 @@ struct InsertSliceOpConversion
     auto smemBase = gep(elemPtrTy, smemObj.base, offset);
 
     auto llSrc = adaptor.source();
-    auto srcIndices =
-        emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape);
-    storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
-                         elemTy, loc, rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
+                             elemTy, loc, rewriter);
     // Barrier is not necessary.
     // The membar pass knows that it writes to shared memory and will handle it
     // properly.
@@ -657,12 +656,12 @@ struct InsertSliceAsyncOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  InsertSliceAsyncOpConversion(LLVMTypeConverter &converter,
-                               const Allocation *allocation, Value smem,
-                               AxisInfoAnalysis &axisAnalysisPass,
-                               PatternBenefit benefit)
+  InsertSliceAsyncOpConversion(
+      LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
-            converter, allocation, smem, benefit),
+            converter, allocation, smem, indexCacheInfo, benefit),
         LoadStoreConversionBase(axisAnalysisPass) {}
 
   LogicalResult
@@ -865,12 +864,12 @@ struct InsertSliceAsyncOpConversion
   }
 };
 
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     PatternBenefit benefit) {
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
@@ -878,7 +877,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
                                       axisInfoAnalysis, benefit);
   patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
-                                        benefit);
+                                        indexCacheInfo, benefit);
   patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
-                                             axisInfoAnalysis, benefit);
+                                             indexCacheInfo, axisInfoAnalysis,
+                                             benefit);
 }
@@ -6,11 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                       RewritePatternSet &patterns,
-                                       int numWarps,
-                                       AxisInfoAnalysis &axisInfoAnalysis,
-                                       const Allocation *allocation, Value smem,
+void populateLoadStoreOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
     PatternBenefit benefit);
 
 #endif
@@ -164,7 +164,7 @@ private:
     auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
 
     SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForBlockedLayout(srcLayout, srcShape);
+        emitOffsetForLayout(srcLayout, srcShape);
 
     std::map<SmallVector<unsigned>, Value> accs;
     std::map<SmallVector<unsigned>, Value> accIndices;
@@ -479,10 +479,12 @@ private:
   }
 };
 
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit) {
-  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
+  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
+                                   indexCacheInfo, benefit);
 }
@@ -6,10 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                    RewritePatternSet &patterns, int numWarps,
-                                    AxisInfoAnalysis &axisInfoAnalysis,
-                                    const Allocation *allocation, Value smem,
-                                    PatternBenefit benefit);
+void populateReduceOpToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 
 #endif
@@ -63,6 +63,7 @@ struct BroadcastOpConversion
     auto srcShape = srcTy.getShape();
     auto resultShape = resultTy.getShape();
     unsigned rank = srcTy.getRank();
+
     assert(rank == resultTy.getRank());
     auto order = triton::gpu::getOrder(srcLayout);
     auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
@@ -272,9 +273,13 @@ struct PrintfOpConversion
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
 
-  MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(converter,
-                                                             benefit) {}
+  MakeRangeOpConversion(
+      LLVMTypeConverter &converter,
+      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+      PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
+            converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
+            benefit) {}
 
   LogicalResult
   matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
@@ -500,11 +505,12 @@ struct AsyncWaitOpConversion
   }
 };
 
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit) {
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit) {
   patterns.add<AddPtrOpConversion>(typeConverter, benefit);
   patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
                                         benefit);
@@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                         benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
-  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
+  patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
 }
@@ -6,10 +6,11 @@
 using namespace mlir;
 using namespace mlir::triton;
 
-void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
-                                     AxisInfoAnalysis &axisInfoAnalysis,
-                                     const Allocation *allocation, Value smem,
-                                     PatternBenefit benefit);
+void populateTritonGPUToLLVMPatterns(
+    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem,
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
+    PatternBenefit benefit);
 
 #endif
@@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
-
 // FuncOpConversion/FuncOpConversionBase is borrowed from
 // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
 // since it is not exposed on header files in mlir v14
@@ -128,7 +127,60 @@ protected:
   }
 };
 
-struct ConvertTritonGPUOpToLLVMPatternBase {
+using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+
+struct CacheKeyDenseMapInfo {
+  static IndexCacheKeyT getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return std::make_pair(
+        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
+        SmallVector<int64_t>{});
+  }
+  static IndexCacheKeyT getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return std::make_pair(
+        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
+        SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
+  }
+  static unsigned getHashValue(IndexCacheKeyT key) {
+    return llvm::hash_combine(
+        mlir::hash_value(key.first),
+        llvm::hash_combine_range(key.second.begin(), key.second.end()));
+  }
+  static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
+    return LHS == RHS;
+  }
+};
+
+class ConvertTritonGPUOpToLLVMPatternBase {
+public:
+  // Two levels of value cache in emitting indices calculation:
+  // Key: pair<layout, shape>
+  struct IndexCacheInfo {
+    DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
+        *baseIndexCache;
+    DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
+             CacheKeyDenseMapInfo> *indexCache;
+    OpBuilder::InsertPoint *indexInsertPoint;
+  };
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
+      : converter(&typeConverter) {}
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
+                                               const Allocation *allocation,
+                                               Value smem)
+      : converter(&typeConverter), allocation(allocation), smem(smem) {}
+
+  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
+                                               const Allocation *allocation,
+                                               Value smem,
+                                               IndexCacheInfo indexCacheInfo)
+      : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
+        allocation(allocation), smem(smem) {}
+
+  LLVMTypeConverter *getTypeConverter() const { return converter; }
+
   static Value
   getStructFromSharedMemoryObject(Location loc,
                                   const SharedMemoryObject &smemObj,
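
Note: the IndexCacheInfo introduced above is a lookup-or-emit cache keyed by
(layout attribute, tensor shape). A minimal standalone analogue of that
caching pattern follows; std::map and std::string stand in for llvm::DenseMap
and mlir::Attribute, and the helper name is hypothetical. Purely illustrative,
not part of the change.

    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    using CacheKey = std::pair<std::string, std::vector<long>>;

    static std::map<CacheKey, std::vector<int>> baseIndexCache;

    // Return cached indices for (layout, shape) or "emit" and cache them once.
    std::vector<int> emitBaseIndex(const std::string &layout,
                                   const std::vector<long> &shape) {
      CacheKey key{layout, shape};
      auto it = baseIndexCache.find(key);
      if (it != baseIndexCache.end())
        return it->second;                          // reuse earlier emission
      std::vector<int> result(shape.size(), 0);     // stand-in for IR emission
      baseIndexCache.emplace(std::move(key), result);
      return result;
    }

    int main() {
      emitBaseIndex("blocked", {128, 64});          // computed and cached
      emitBaseIndex("blocked", {128, 64});          // served from the cache
      std::printf("%zu\n", baseIndexCache.size());  // prints 1
      return 0;
    }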
@@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
         LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
     return getStructFromElements(loc, elems, rewriter, structTy);
   }
-};
-
-template <typename SourceOp>
-class ConvertTritonGPUOpToLLVMPattern
-    : public ConvertOpToLLVMPattern<SourceOp>,
-      public ConvertTritonGPUOpToLLVMPatternBase {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-
-  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                           const Allocation *allocation,
-                                           Value smem,
-                                           PatternBenefit benefit = 1)
-      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
-        allocation(allocation), smem(smem) {}
-
   Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
@@ -169,6 +202,23 @@ public:
     return threadId;
   }
 
+  // -----------------------------------------------------------------------
+  // Shared memory utilities
+  // -----------------------------------------------------------------------
+  template <typename T>
+  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
+                            T value) const {
+
+    auto ptrTy = LLVM::LLVMPointerType::get(
+        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
+    auto bufferId = allocation->getBufferId(value);
+    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
+    size_t offset = allocation->getOffset(bufferId);
+    Value offVal = idx_val(offset);
+    Value base = gep(ptrTy, smem, offVal);
+    return base;
+  }
+
   // -----------------------------------------------------------------------
   // Utilities
   // -----------------------------------------------------------------------
@@ -242,6 +292,116 @@ public:
     return ret;
   }
 
+  struct SmallVectorKeyInfo {
+    static unsigned getHashValue(const SmallVector<unsigned> &key) {
+      return llvm::hash_combine_range(key.begin(), key.end());
+    }
+    static bool isEqual(const SmallVector<unsigned> &lhs,
+                        const SmallVector<unsigned> &rhs) {
+      return lhs == rhs;
+    }
+    static SmallVector<unsigned> getEmptyKey() {
+      return SmallVector<unsigned>();
+    }
+    static SmallVector<unsigned> getTombstoneKey() {
+      return {std::numeric_limits<unsigned>::max()};
+    }
+  };
+
+  // -----------------------------------------------------------------------
+  // Get offsets / indices for any layout
+  // -----------------------------------------------------------------------
+
+  SmallVector<Value> emitBaseIndexForLayout(Location loc,
+                                            ConversionPatternRewriter &rewriter,
+                                            const Attribute &layout,
+                                            ArrayRef<int64_t> shape) const {
+    IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
+    auto cache = indexCacheInfo.baseIndexCache;
+    assert(cache && "baseIndexCache is nullptr");
+    auto insertPt = indexCacheInfo.indexInsertPoint;
+    if (cache->count(key) > 0) {
+      return cache->lookup(key);
+    } else {
+      ConversionPatternRewriter::InsertionGuard guard(rewriter);
+      restoreInsertionPointIfSet(insertPt, rewriter);
+      SmallVector<Value> result;
+      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+        result =
+            emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
+      } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+        if (mmaLayout.isVolta())
+          result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
+        if (mmaLayout.isAmpere())
+          result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
+      } else {
+        llvm_unreachable("unsupported emitBaseIndexForLayout");
+      }
+      cache->insert(std::make_pair(key, result));
+      *insertPt = rewriter.saveInsertionPoint();
+      return result;
+    }
+  }
+
+  SmallVector<SmallVector<unsigned>>
+  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
+    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
+      return emitOffsetForBlockedLayout(blockedLayout, shape);
+    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.isVolta())
+        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
+      if (mmaLayout.isAmpere())
+        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
+    }
+    llvm_unreachable("unsupported emitOffsetForLayout");
+  }
+
+  // -----------------------------------------------------------------------
+  // Emit indices
+  // -----------------------------------------------------------------------
+  SmallVector<SmallVector<Value>> emitIndices(Location loc,
+                                              ConversionPatternRewriter &b,
+                                              const Attribute &layout,
+                                              ArrayRef<int64_t> shape) const {
+    IndexCacheKeyT key(layout, llvm::to_vector(shape));
+    auto cache = indexCacheInfo.indexCache;
+    assert(cache && "indexCache is nullptr");
+    auto insertPt = indexCacheInfo.indexInsertPoint;
+    if (cache->count(key) > 0) {
+      return cache->lookup(key);
+    } else {
+      ConversionPatternRewriter::InsertionGuard guard(b);
+      restoreInsertionPointIfSet(insertPt, b);
+      SmallVector<SmallVector<Value>> result;
+      if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
+        result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
+      } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
+        result = emitIndicesForDistributedLayout(loc, b, mma, shape);
+      } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
+        result = emitIndicesForSliceLayout(loc, b, slice, shape);
+      } else {
+        llvm_unreachable(
+            "emitIndices for layouts other than blocked & slice not "
+            "implemented yet");
+      }
+      cache->insert(std::make_pair(key, result));
+      *insertPt = b.saveInsertionPoint();
+      return result;
+    }
+  }
+
+private:
+  void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
+                                  ConversionPatternRewriter &rewriter) const {
+    if (insertPt->isSet()) {
+      rewriter.restoreInsertionPoint(*insertPt);
+    } else {
+      auto func =
+          rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
+      rewriter.setInsertionPointToStart(&func.getBody().front());
+    }
+  }
+
   // -----------------------------------------------------------------------
   // Blocked layout indices
   // -----------------------------------------------------------------------
@@ -411,38 +571,6 @@ public:
     return ret;
   }
 
-  // -----------------------------------------------------------------------
-  // Get offsets / indices for any layout
-  // -----------------------------------------------------------------------
-
-  SmallVector<Value> emitBaseIndexForLayout(Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            const Attribute &layout,
-                                            ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitBaseIndexForLayout");
-  }
-
-  SmallVector<SmallVector<unsigned>>
-  emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
-      return emitOffsetForBlockedLayout(blockedLayout, shape);
-    if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
-      if (mmaLayout.isVolta())
-        return emitOffsetForMmaLayoutV1(mmaLayout, shape);
-      if (mmaLayout.isAmpere())
-        return emitOffsetForMmaLayoutV2(mmaLayout, shape);
-    }
-    llvm_unreachable("unsupported emitOffsetForLayout");
-  }
-
   // Emit indices calculation within each ConversionPattern, and returns a
   // [elemsPerThread X rank] index matrix.
 
@@ -470,22 +598,6 @@ public:
     return multiDimIdx;
   }
 
-  struct SmallVectorKeyInfo {
-    static unsigned getHashValue(const SmallVector<unsigned> &key) {
-      return llvm::hash_combine_range(key.begin(), key.end());
-    }
-    static bool isEqual(const SmallVector<unsigned> &lhs,
-                        const SmallVector<unsigned> &rhs) {
-      return lhs == rhs;
-    }
-    static SmallVector<unsigned> getEmptyKey() {
-      return SmallVector<unsigned>();
-    }
-    static SmallVector<unsigned> getTombstoneKey() {
-      return {std::numeric_limits<unsigned>::max()};
-    }
-  };
-
   SmallVector<SmallVector<Value>>
   emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
                             const SliceEncodingAttr &sliceLayout,
@@ -505,46 +617,45 @@ public:
     return resultIndices;
   }
 
-  // -----------------------------------------------------------------------
-  // Emit indices
-  // -----------------------------------------------------------------------
-  SmallVector<SmallVector<Value>> emitIndices(Location loc,
-                                              ConversionPatternRewriter &b,
-                                              const Attribute &layout,
-                                              ArrayRef<int64_t> shape) const {
-    if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, blocked, shape);
-    } else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
-      return emitIndicesForDistributedLayout(loc, b, mma, shape);
-    } else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
-      return emitIndicesForSliceLayout(loc, b, slice, shape);
-    } else {
-      assert(0 && "emitIndices for layouts other than blocked & slice not "
-                  "implemented yet");
-      return {};
-    }
-  }
-
-  // -----------------------------------------------------------------------
-  // Shared memory utilities
-  // -----------------------------------------------------------------------
-  template <typename T>
-  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
-                            T value) const {
-
-    auto ptrTy = LLVM::LLVMPointerType::get(
-        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
-    auto bufferId = allocation->getBufferId(value);
-    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
-    size_t offset = allocation->getOffset(bufferId);
-    Value offVal = idx_val(offset);
-    Value base = gep(ptrTy, smem, offVal);
-    return base;
-  }
-
 protected:
+  LLVMTypeConverter *converter;
   const Allocation *allocation;
   Value smem;
+  IndexCacheInfo indexCacheInfo;
+};
+
+template <typename SourceOp>
+class ConvertTritonGPUOpToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp>,
+      public ConvertTritonGPUOpToLLVMPatternBase {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
+
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           IndexCacheInfo indexCacheInfo,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
+                                            indexCacheInfo) {}
+
+protected:
+  LLVMTypeConverter *getTypeConverter() const {
+    return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
+  }
 };
 
 #endif
@@ -170,16 +170,20 @@ public:
     // We set a higher benefit here to ensure triton's patterns runs before
     // arith patterns for some encoding not supported by the community
    // patterns.
+    OpBuilder::InsertPoint indexInsertPoint;
+    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
+        &baseIndexCache, &indexCache, &indexInsertPoint};
+
     RewritePatternSet patterns(context);
 
     // Normal conversions
     populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
                                     axisInfoAnalysis, &allocation, smem,
-                                    /*benefit=*/10);
+                                    indexCacheInfo, /*benefit=*/10);
     // ConvertLayoutOp
     populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                           axisInfoAnalysis, &allocation, smem,
-                                          /*benefit=*/10);
+                                          indexCacheInfo, /*benefit=*/10);
     // DotOp
     populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                 axisInfoAnalysis, &allocation, smem,
@@ -191,11 +195,11 @@ public:
     // LoadStoreOp
     populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                       axisInfoAnalysis, &allocation, smem,
-                                      /*benefit=*/10);
+                                      indexCacheInfo, /*benefit=*/10);
     // ReduceOp
     populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                    axisInfoAnalysis, &allocation, smem,
-                                   /*benefit=*/10);
+                                   indexCacheInfo, /*benefit=*/10);
     // ViewOp
     populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                  axisInfoAnalysis, &allocation, smem,
@@ -215,6 +219,13 @@ public:
 private:
   Value smem;
 
+  using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+  DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
+      baseIndexCache;
+  DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
+           CacheKeyDenseMapInfo>
+      indexCache;
+
   int computeCapability{};
 
   void initSharedMemory(size_t size,
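The two maps added above memoize index computations keyed on a (layout attribute, tensor shape) pair, so repeated conversions over the same layout and shape can reuse previously emitted values instead of re-emitting them. The following standalone C++ sketch illustrates only that memoization idea; the plain std::string/std::vector key and the emitIndices helper are illustrative stand-ins, not the MLIR Attribute/Value types or any API from this patch.

// Standalone illustration of the (layout, shape) -> indices memoization idea;
// std::string and std::vector stand in for the MLIR Attribute/Value types.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using CacheKey = std::pair<std::string, std::vector<int64_t>>; // (layout, shape)

std::map<CacheKey, std::vector<int64_t>> indexCache;

std::vector<int64_t> emitIndices(const std::string &layout,
                                 const std::vector<int64_t> &shape) {
  CacheKey key{layout, shape};
  auto it = indexCache.find(key);
  if (it != indexCache.end())
    return it->second; // cache hit: reuse previously emitted indices
  std::vector<int64_t> indices(shape.empty() ? 0 : shape[0]);
  for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); ++i)
    indices[i] = i; // placeholder for the real per-thread index computation
  indexCache.emplace(key, indices);
  return indices;
}

int main() {
  emitIndices("#blocked0", {256}); // computes and caches
  emitIndices("#blocked0", {256}); // cache hit: no recomputation
  std::cout << "cached entries: " << indexCache.size() << "\n"; // prints 1
}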
@@ -4,11 +4,13 @@
 using namespace mlir;
 using namespace mlir::triton;
 
+using ::mlir::LLVM::DotOpFMAConversionHelper;
 using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
 using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
 using ::mlir::LLVM::getElementsFromStruct;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStructFromElements;
+using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::getElemsPerThread;
 
 struct SplatOpConversion
@@ -38,6 +40,11 @@ struct SplatOpConversion
           LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
 
       return getStructFromElements(loc, elems, rewriter, structTy);
+    } else if (auto dotLayout =
+                   tensorTy.getEncoding()
+                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
+      return convertSplatLikeOpWithDotOperandLayout(
+          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
     } else if (auto mmaLayout =
                    tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
       return convertSplatLikeOpWithMmaLayout(
@@ -48,6 +55,38 @@ struct SplatOpConversion
     return {};
   }
 
+  static Value convertSplatLikeOpWithDotOperandLayout(
+      const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
+      Type elemType, Value constVal, TypeConverter *typeConverter,
+      ConversionPatternRewriter &rewriter, Location loc) {
+    auto tensorTy = resType.cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto parent = layout.getParent();
+    int numElems{};
+    if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.isAmpere()) {
+        numElems = layout.getOpIdx() == 0
+                       ? MMA16816ConversionHelper::getANumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[0])
+                       : MMA16816ConversionHelper::getBNumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[1]);
+      } else if (mmaLayout.isVolta()) {
+        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        numElems = layout.getOpIdx() == 0
+                       ? helper.numElemsPerThreadA(shape, {0, 1})
+                       : helper.numElemsPerThreadB(shape, {0, 1});
+      }
+    } else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
+      numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
+    } else {
+      assert(false && "Unsupported layout found");
+    }
+    auto structTy = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), SmallVector<Type>(numElems, elemType));
+    return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
+                                 rewriter, structTy);
+  }
+
   static Value convertSplatLikeOpWithMmaLayout(
       const MmaEncodingAttr &layout, Type resType, Type elemType,
       Value constVal, TypeConverter *typeConverter,
@@ -257,6 +257,11 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
   }
 };
 
+bool isaDistributedLayout(const Attribute &layout) {
+  return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
+         layout.isa<SliceEncodingAttr>();
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
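The helper added above classifies an encoding as one of the register-"distributed" layouts (blocked, MMA, or a slice of one), as opposed to shared-memory layouts. A standalone mimic of that predicate, using an enum in place of the MLIR attribute classes (an illustrative assumption, not the real API):

// Toy stand-in for the classification above: Blocked, Mma and Slice count as
// "distributed"; Shared and DotOperand do not.
#include <iostream>

enum class Layout { Blocked, Mma, Slice, Shared, DotOperand };

bool isaDistributedLayout(Layout layout) {
  return layout == Layout::Blocked || layout == Layout::Mma ||
         layout == Layout::Slice;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << isaDistributedLayout(Layout::Blocked) << "\n"; // true
  std::cout << isaDistributedLayout(Layout::Shared) << "\n";  // false
}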
@@ -23,6 +23,10 @@
 using namespace mlir;
 namespace {
 #include "TritonGPUCombine.inc"
+using triton::DotOp;
+using triton::gpu::ConvertLayoutOp;
+using triton::gpu::DotOperandEncodingAttr;
+using triton::gpu::MmaEncodingAttr;
 
 // -----------------------------------------------------------------------------
 //
@@ -1020,6 +1024,7 @@ public:
         dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
     if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
       return failure();
+
     auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
     auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
         op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -1061,7 +1066,8 @@ public:
     auto dotOp = cast<triton::DotOp>(op);
     // TODO: Check data-types and SM compatibility
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
-    if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
+    if (!oldRetType.getEncoding() ||
+        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
 
     auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
@@ -1171,7 +1177,8 @@ public:
     for (size_t i = 0; i < newInitArgs.size(); i++) {
       auto initArg = newInitArgs[i];
       auto regionArg = forOp.getRegionIterArgs()[i];
-      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
+      if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
+          newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
         shouldRematerialize = true;
         break;
       }
@@ -1187,15 +1194,207 @@ public:
     BlockAndValueMapping mapping;
     for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
       mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+    mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
 
     for (Operation &op : forOp.getBody()->getOperations()) {
-      Operation *newOp = rewriter.clone(op, mapping);
+      rewriter.clone(op, mapping);
     }
     rewriter.replaceOp(forOp, newForOp.getResults());
     return success();
   }
 };
 
+// This pattern collects the wrong Mma those need to update and create the right
+// ones for each.
+class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
+  DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+
+public:
+  CollectMmaToUpdateForVolta(
+      mlir::MLIRContext *ctx,
+      DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
+        mmaToUpdate(mmaToUpdate) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    auto dotOp = cast<triton::DotOp>(op);
+    auto *ctx = dotOp->getContext();
+    auto AT = dotOp.a().getType().cast<RankedTensorType>();
+    auto BT = dotOp.b().getType().cast<RankedTensorType>();
+    auto DT = dotOp.d().getType().cast<RankedTensorType>();
+    if (!DT.getEncoding())
+      return failure();
+    auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
+    if (!(mmaLayout && mmaLayout.isVolta()))
+      return failure();
+
+    // Has processed.
+    if (mmaToUpdate.count(mmaLayout))
+      return failure();
+
+    auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
+    auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
+    bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto [isARow_, isBRow_, isAVec4, isBVec4] =
+        mmaLayout.decodeVoltaLayoutStates();
+    if (isARow_ == isARow && isBRow_ == isBRow) {
+      return failure(); // No need to update
+    }
+
+    auto newMmaLayout = MmaEncodingAttr::get(
+        ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
+        AT.getShape(), BT.getShape(), isARow, isBRow);
+
+    // Collect the wrong MMA Layouts, and mark need to update.
+    mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
+
+    return failure();
+  }
+};
+
+// Correct the versionMinor field in MmaEncodingAttr for Volta.
+class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
+  const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
+  enum class Kind {
+    kUnk,
+    kCvtToMma,
+    kCvtToDotOp,
+    kDot,
+    kConstant,
+  };
+  mutable Kind rewriteKind{Kind::kUnk};
+
+public:
+  UpdateMMAVersionMinorForVolta(
+      mlir::MLIRContext *ctx, llvm::StringRef opName,
+      const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
+      : RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
+
+  LogicalResult match(Operation *op) const override {
+    MmaEncodingAttr mma;
+    if (mmaToUpdate.empty())
+      return failure();
+    if (op->getNumResults() != 1)
+      return failure();
+    auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+    if (!tensorTy)
+      return failure();
+
+    // ConvertLayoutOp
+    if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
+      // cvt X -> dot_operand
+      if (auto dotOperand =
+              tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
+        mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
+        rewriteKind = Kind::kCvtToDotOp;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+      if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
+        // cvt X -> mma
+        rewriteKind = Kind::kCvtToMma;
+        if (mma && mmaToUpdate.count(mma))
+          return success();
+      }
+    } else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
+      // DotOp
+      mma = dot.d()
+                .getType()
+                .cast<RankedTensorType>()
+                .getEncoding()
+                .dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kDot;
+    } else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
+      // ConstantOp
+      mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
+      rewriteKind = Kind::kConstant;
+    }
+
+    return success(mma && mmaToUpdate.count(mma));
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    switch (rewriteKind) {
+    case Kind::kDot:
+      rewriteDot(op, rewriter);
+      break;
+    case Kind::kConstant:
+      rewriteConstant(op, rewriter);
+      break;
+    case Kind::kCvtToDotOp:
+      rewriteCvtDotOp(op, rewriter);
+      break;
+    case Kind::kCvtToMma:
+      rewriteCvtToMma(op, rewriter);
+      break;
+    default:
+      llvm::report_fatal_error("Not supported rewrite kind");
+    }
+  }
+
+private:
+  void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    MmaEncodingAttr newMma =
+        mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
+    auto newDotOperand = DotOperandEncodingAttr::get(
+        ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
+    auto newTensorTy = RankedTensorType::get(
+        tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto dot = llvm::cast<DotOp>(op);
+    auto tensorTy = dot.d().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
+                                       dot.c(), dot.allowTF32());
+  }
+
+  void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto cvt = llvm::cast<ConvertLayoutOp>(op);
+    auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
+                                                 cvt.getOperand());
+  }
+
+  void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
+    auto *ctx = op->getContext();
+    auto constant = llvm::cast<arith::ConstantOp>(op);
+    auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
+    auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto newMma = mmaToUpdate.lookup(mma);
+    auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
+                                             tensorTy.getElementType(), newMma);
+    if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
+      auto newRet =
+          SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
+      return;
+    }
+
+    assert(false && "Not supported ConstantOp value type");
+  }
+};
+
 } // namespace
 
 #define GEN_PASS_CLASSES
@@ -1230,6 +1429,28 @@ public:
       signalPassFailure();
     }
 
+    llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
+        signalPassFailure();
+    }
+    {
+      mlir::RewritePatternSet patterns(context);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, DotOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
+      patterns.add<UpdateMMAVersionMinorForVolta>(
+          context, arith::ConstantOp::getOperationName(), mmaToUpdate);
+      mlir::GreedyRewriteConfig config;
+      config.useTopDownTraversal = true;
+
+      if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
+        signalPassFailure();
+    }
+
     mlir::RewritePatternSet loopFixup(context);
     loopFixup.add<FixupLoop>(context);
     if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {
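The wiring above runs in two phases: a first greedy walk only records, in mmaToUpdate, which Volta MMA layouts carry a stale versionMinor, and a second walk rewrites every op kind that refers to a recorded layout (dot, convert_layout, constant), so all users end up agreeing on the same corrected attribute. A minimal standalone C++ sketch of that collect-then-rewrite shape, with strings standing in for layout attributes (illustrative only, not the MLIR pattern API):

// Phase 1 collects layouts that must change; phase 2 rewrites every op that
// still refers to a collected layout. Strings stand in for MMA layout attrs.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ops = {"tt.dot -> #mma_v1_wrong",
                                  "convert_layout -> #mma_v1_wrong",
                                  "arith.constant -> #mma_v1_ok"};

  // Phase 1: collect (analogue of CollectMmaToUpdateForVolta).
  std::map<std::string, std::string> mmaToUpdate;
  for (const auto &op : ops)
    if (op.find("#mma_v1_wrong") != std::string::npos)
      mmaToUpdate.emplace("#mma_v1_wrong", "#mma_v1_fixed");

  // Phase 2: rewrite users (analogue of UpdateMMAVersionMinorForVolta,
  // which is registered once per op kind above).
  for (auto &op : ops)
    for (const auto &[oldMma, newMma] : mmaToUpdate) {
      auto pos = op.find(oldMma);
      if (pos != std::string::npos)
        op.replace(pos, oldMma.size(), newMma);
    }

  for (const auto &op : ops)
    std::cout << op << "\n"; // first two ops now reference #mma_v1_fixed
}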
@@ -141,10 +141,10 @@ class CMakeBuild(build_ext):
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
             "-DTRITON_BUILD_TUTORIALS=OFF",
             "-DTRITON_BUILD_PYTHON_MODULE=ON",
-            # '-DPYTHON_EXECUTABLE=' + sys.executable,
-            '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
+            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
+            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
             "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
-            "-DLLVM_EXTERNAL_LIT=" + lit_dir
+            "-DLLVM_EXTERNAL_LIT=" + lit_dir,
         ] + thirdparty_cmake_args
 
         # configuration
@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
 # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
 @pytest.mark.parametrize("expr, dtype_str", [
     (f'x[{s}]', d)
-    for s in ['None, :', ':, None']
-    # FIXME: 3d indexing doesn't work
-    #'None, :, :',
-    # ':, :, None']
+    for s in ['None, :', ':, None',
+              'None, :, :',
+              ':, :, None']
     for d in ['int32', 'uint32', 'uint16']
 ])
 def test_index1d(expr, dtype_str, device='cuda'):
@@ -1228,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
     elif dtype == 'int8':
         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
-# FIXME: Unsupported layout found in ConvertSplatLikeOp
-# def test_dot_without_load():
-#     @triton.jit
-#     def kernel(out):
-#         pid = tl.program_id(axis=0)
-#         a = tl.zeros((32, 32), tl.float32)
-#         b = tl.zeros((32, 32), tl.float32)
-#         c = tl.zeros((32, 32), tl.float32)
-#         c = tl.dot(a, b)
-#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
-#         tl.store(pout, c)
-#
-#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
-#     kernel[(1,)](out)
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
 
 # ---------------
 # test arange
|
@@ -20,6 +20,8 @@ from .core import (
|
|||||||
atomic_xor,
|
atomic_xor,
|
||||||
bfloat16,
|
bfloat16,
|
||||||
block_type,
|
block_type,
|
||||||
|
broadcast,
|
||||||
|
broadcast_to,
|
||||||
cat,
|
cat,
|
||||||
cdiv,
|
cdiv,
|
||||||
constexpr,
|
constexpr,
|
||||||
@@ -105,6 +107,8 @@ __all__ = [
     "atomic_xor",
     "bfloat16",
     "block_type",
+    "broadcast",
+    "broadcast_to",
     "builtin",
     "cat",
     "cdiv",
@@ -596,11 +596,9 @@ class tensor:
         if isinstance(slices, slice):
             slices = [slices]
         ret = self
-        n_inserted = 0
         for dim, sl in enumerate(slices):
             if isinstance(sl, constexpr) and sl.value is None:
-                ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
-                n_inserted += 1
+                ret = semantic.expand_dims(ret, dim, _builder)
             elif sl == slice(None, None, None):
                 pass
             else:
@@ -997,20 +997,61 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
 // -----
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
-  // CHECK: nvvm.read.ptx.sreg.nctaid.x
-  // CHECK: nvvm.read.ptx.sreg.nctaid.y
-  // CHECK: nvvm.read.ptx.sreg.nctaid.z
-  %blockdimx = tt.get_num_programs {axis=0:i32} : i32
-  %blockdimy = tt.get_num_programs {axis=1:i32} : i32
-  %blockdimz = tt.get_num_programs {axis=2:i32} : i32
-  %v0 = arith.addi %blockdimx, %blockdimy : i32
-  %v1 = arith.addi %v0, %blockdimz : i32
-  %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
-  tt.store %a, %0 : tensor<32xi32, #blocked0>
-
-  return
+  func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+    // CHECK: nvvm.read.ptx.sreg.nctaid.x
+    // CHECK: nvvm.read.ptx.sreg.nctaid.y
+    // CHECK: nvvm.read.ptx.sreg.nctaid.z
+    %blockdimx = tt.get_num_programs {axis=0:i32} : i32
+    %blockdimy = tt.get_num_programs {axis=1:i32} : i32
+    %blockdimz = tt.get_num_programs {axis=2:i32} : i32
+    %v0 = arith.addi %blockdimx, %blockdimy : i32
+    %v1 = arith.addi %v0, %blockdimz : i32
+    %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
+    tt.store %a, %0 : tensor<32xi32, #blocked0>
+
+    return
+  }
 }
 
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: test_index_cache
+  func @test_index_cache() {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
+    // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+    %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
+    return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: test_base_index_cache
+  func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+    %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    return
+  }
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: test_index_cache_different_block
+  func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+    // CHECK: nvvm.read.ptx.sreg.tid.x
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    scf.if %arg1 {
+      // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
+      %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
+    }
+    return
+  }
 }
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
 
 #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
 // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
 // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-
 func @cst() -> tensor<1024xi32, #layout1> {
   %cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
   %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 // CHECK-LABEL: transpose
 func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
-  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
-  // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
   // CHECK: return
   %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
   %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
   tt.store %21, %22 : tensor<256xf32, #layout1>
   return
 }
+
+
+// -----
+
+// check the UpdateMMAVersionMinorForVolta pattern
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
+#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
+// Here, the isMMAv1Row of a and b's dot_operands mismatch #mma0's versionMinor,
+// and the pattern should update the versionMinor.
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
+// It creates a new MMA layout to fit with $a and $b's dot_operand
+// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: dot_mmav1
+  func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
+    %C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
+    %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
+    %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
+    %CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
+
+    // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
+    %D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+    %res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
+
+    return %res : tensor<16x16xf32, #blocked0>
+  }
+}