[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)
@@ -29,7 +29,9 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -350,6 +352,13 @@ public:
  return threadId;
}

Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
                       int64_t value) const {
  return rewriter.create<LLVM::ConstantOp>(
      loc, this->getTypeConverter()->getIndexType(),
      rewriter.getIntegerAttr(rewriter.getIndexType(), value));
}

SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                               Location loc, Value linear,
                               ArrayRef<unsigned> shape,
@@ -423,7 +432,7 @@ public:
  auto order = blocked_layout.getOrder();
  unsigned rank = shape.size();

  // step 1, delinearize threadId to get the base index
  // delinearize threadId to get the base index
  SmallVector<Value> multiDimWarpId =
      delinearize(rewriter, loc, warpId, warpsPerCTA, order);
  SmallVector<Value> multiDimThreadId =
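Not part of the diff: a host-side sketch of the index arithmetic used here, with hypothetical helper names. It assumes the usual blocked-layout convention that order lists dimensions from fastest- to slowest-varying and that the base index per dimension is (warp index * threadsPerWarp + lane index) * sizePerThread; the real code emits LLVM ops instead of computing plain integers.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for the delinearize helper above: peel off one index
// per dimension, walking dimensions in `order` (fastest-varying first).
std::vector<unsigned> delinearizeModel(unsigned linear,
                                       const std::vector<unsigned> &shape,
                                       const std::vector<unsigned> &order) {
  std::vector<unsigned> multiDim(shape.size());
  for (unsigned dim : order) {
    multiDim[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return multiDim;
}

int main() {
  // Assumed example layout: sizePerThread {1, 4}, threadsPerWarp {8, 4},
  // warpsPerCTA {2, 2}, order {1, 0} (dimension 1 is fastest-varying).
  std::vector<unsigned> sizePerThread{1, 4}, threadsPerWarp{8, 4},
      warpsPerCTA{2, 2}, order{1, 0};
  unsigned threadId = 45;
  unsigned warpId = threadId / 32; // warp within the CTA
  unsigned laneId = threadId % 32; // lane within the warp
  auto multiDimWarpId = delinearizeModel(warpId, warpsPerCTA, order);
  auto multiDimThreadId = delinearizeModel(laneId, threadsPerWarp, order);
  for (unsigned d = 0; d < 2; ++d) {
    // First element owned by this thread along dimension d.
    unsigned base = (multiDimWarpId[d] * threadsPerWarp[d] +
                     multiDimThreadId[d]) *
                    sizePerThread[d];
    std::printf("dim %u: base index %u\n", d, base);
  }
  return 0;
}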
@@ -455,6 +464,13 @@ public:
  return multiDimBase;
}

SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
                              const MmaEncodingAttr &mmaLayout,
                              ArrayRef<int64_t> shape) const {
  // ongoing
}

SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                            ConversionPatternRewriter &b,
                                            const Attribute &layout,
@@ -1459,9 +1475,11 @@ public:
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
    (!dstLayout.isa<BlockedEncodingAttr>())) {
  // TODO: not implemented
if ((!srcLayout.isa<BlockedEncodingAttr>() &&
     !srcLayout.isa<MmaEncodingAttr>()) ||
    (!dstLayout.isa<BlockedEncodingAttr>() &&
     !dstLayout.isa<MmaEncodingAttr>())) {
  // TODO: to be implemented
  return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1471,31 +1489,6 @@ public:
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
auto getContigPerThread = [&](const Attribute &layout,
                              unsigned d) -> unsigned {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getSizePerThread()[d];
  } else {
    assert(0 && "Unimplemented usage of getContigPerThread");
    return 0;
  }
};
auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return product<unsigned>(blockedLayout.getSizePerThread());
  } else {
    assert(0 && "Unimplemented usage of getAccumElemsPerThread");
    return 0;
  }
};
auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getOrder();
  } else {
    assert(0 && "Unimplemented usage of getAccumElemsPerThread");
    return {};
  }
};
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
@@ -1517,7 +1510,6 @@ public:
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
@@ -1530,19 +1522,21 @@ public:
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
  auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
  rewriter.create<mlir::gpu::BarrierOp>(loc);
  if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
    processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
                          inNumCTAsEachRep, multiDimRepId, inVec,
                          paddedRepShape, outOrd, vals, smemBase);
  if (srcLayout.isa<BlockedEncodingAttr>() ||
      srcLayout.isa<MmaEncodingAttr>()) {
    processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
                   multiDimRepId, inVec, paddedRepShape, outOrd, vals,
                   smemBase);
  } else {
    assert(0 && "ConvertLayout with input layout not implemented");
    return failure();
  }
  rewriter.create<mlir::gpu::BarrierOp>(loc);
  if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
    processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
                          outNumCTAsEachRep, multiDimRepId, outVec,
                          paddedRepShape, outOrd, outVals, smemBase);
  if (dstLayout.isa<BlockedEncodingAttr>() ||
      dstLayout.isa<MmaEncodingAttr>()) {
    processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
                   outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
                   outOrd, outVals, smemBase);
  } else {
    assert(0 && "ConvertLayout with output layout not implemented");
    return failure();
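Not part of the diff: the conversion above is a shared-memory round trip, i.e. each replica is written out under the source layout, a barrier is issued, and the same scratch buffer is read back under the destination layout. Below is a toy host-side model of that scatter/gather structure; the layouts, sizes, and names are invented for illustration only.

#include <array>
#include <cstdio>

constexpr unsigned kThreads = 4; // threads in this toy example
constexpr unsigned kSlots = 2;   // register slots per thread
constexpr unsigned kSize = kThreads * kSlots;

// Invented "layouts": which logical tensor element a (thread, slot) pair owns.
unsigned srcOwner(unsigned thread, unsigned slot) { return thread * kSlots + slot; }
unsigned dstOwner(unsigned thread, unsigned slot) { return slot * kThreads + thread; }

int main() {
  std::array<int, kSize> regsIn{}, scratch{}, regsOut{};
  for (unsigned i = 0; i < kSize; ++i)
    regsIn[i] = 100 + i; // values held in registers under the source layout

  // Store phase: every thread writes its values at their logical positions.
  for (unsigned t = 0; t < kThreads; ++t)
    for (unsigned s = 0; s < kSlots; ++s)
      scratch[srcOwner(t, s)] = regsIn[t * kSlots + s];

  // (On the GPU a barrier separates the store and load phases.)

  // Load phase: every thread reads the positions it owns under the new layout.
  for (unsigned t = 0; t < kThreads; ++t)
    for (unsigned s = 0; s < kSlots; ++s)
      regsOut[t * kSlots + s] = scratch[dstOwner(t, s)];

  for (int v : regsOut)
    std::printf("%d ", v);
  std::printf("\n");
  return 0;
}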
@@ -1568,30 +1562,58 @@ private:
  return result;
};

void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
                           bool stNotRd, RankedTensorType type,
                           ArrayRef<unsigned> numCTAsEachRep,
                           ArrayRef<unsigned> multiDimRepId, unsigned vec,
                           ArrayRef<unsigned> paddedRepShape,
                           ArrayRef<unsigned> outOrd,
                           SmallVector<Value> &vals, Value smemBase) const {
// shared memory access for blocked or mma layout
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
                    bool stNotRd, RankedTensorType type,
                    ArrayRef<unsigned> numCTAsEachRep,
                    ArrayRef<unsigned> multiDimRepId, unsigned vec,
                    ArrayRef<unsigned> paddedRepShape,
                    ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
                    Value smemBase) const {
  unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
  auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
  auto layout = type.getEncoding();
  auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
  auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
  auto rank = type.getRank();
  auto sizePerThread = layout.getSizePerThread();
  auto sizePerThread = getSizePerThread(layout);
  auto accumSizePerThread = product<unsigned>(sizePerThread);
  auto llvmIndexTy = getTypeConverter()->getIndexType();
  SmallVector<unsigned> numCTAs(rank);
  SmallVector<unsigned> shapePerCTA(rank);
  for (unsigned d = 0; d < rank; ++d) {
    shapePerCTA[d] = layout.getSizePerThread()[d] *
                     layout.getThreadsPerWarp()[d] *
                     layout.getWarpsPerCTA()[d];
    shapePerCTA[d] = getShapePerCTA(layout, d);
    numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
  }
  auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
  auto multiDimOffsetFirstElem =
      emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
  SmallVector<Value> multiDimOffsetFirstElem;
  Value mmaGrpId;
  Value mmaGrpIdP8;
  Value mmaThreadIdInGrpM2;
  Value mmaThreadIdInGrpM2P1;
  if (blockedLayout) {
    multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
        loc, rewriter, blockedLayout, type.getShape());
  } else if (mmaLayout) {
    // TODO: simplify these
    auto cast = rewriter.create<UnrealizedConversionCastOp>(
        loc, TypeRange{llvmIndexTy},
        ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
    Value threadId = cast.getResult(0);
    Value warpSize = createIndexAttrConstant(
        rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
    Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
    Value fourVal = createIndexConst(rewriter, loc, 4);
    mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
    mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(
        loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
    Value mmaThreadIdInGrp =
        rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
    mmaThreadIdInGrpM2 = rewriter.create<LLVM::MulOp>(
        loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
    mmaThreadIdInGrpM2P1 = rewriter.create<LLVM::AddOp>(
        loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
  }
  for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
    auto multiDimCTAInRepId =
        getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
@@ -1605,18 +1627,27 @@ private:
      // TODO: This is actually redundant index calculation, we should
      // consider of caching the index calculation result in case
      // of performance issue observed.
      // for (unsigned elemId = linearCTAId * accumSizePerThread;
      //      elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
      for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
        auto multiDimElemId =
            getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
        SmallVector<Value> multiDimOffset(rank);
        for (unsigned d = 0; d < rank; ++d) {
          multiDimOffset[d] = add(
              multiDimOffsetFirstElem[d],
              createIndexAttrConstant(rewriter, loc, llvmIndexTy,
                                      multiDimCTAInRepId[d] * shapePerCTA[d] +
                                          multiDimElemId[d]));
        if (blockedLayout) {
          SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
              elemId, blockedLayout.getSizePerThread());
          for (unsigned d = 0; d < rank; ++d) {
            multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
                loc, multiDimOffsetFirstElem[d],
                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
                                        multiDimCTAInRepId[d] * shapePerCTA[d] +
                                            multiDimElemId[d]));
          }
        } else if (mmaLayout) {
          assert(rank == 2);
          assert(mmaLayout.getVersion() == 2 &&
                 "mmaLayout ver1 not implemented yet");
          multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
          multiDimOffset[1] =
              elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
        } else {
          assert(0 && "unexpected layout in processReplica");
        }
        Value offset =
            linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
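Not part of the diff: for the mma (version 2) branch, the offsets above encode the standard 16x8 accumulator ownership within a warp: lane laneId holds four values, at rows laneId / 4 and laneId / 4 + 8 and columns 2 * (laneId % 4) and 2 * (laneId % 4) + 1. A host-side sketch of that mapping, mirroring mmaGrpId, mmaGrpIdP8, mmaThreadIdInGrpM2 and mmaThreadIdInGrpM2P1:

#include <cstdio>

int main() {
  for (unsigned laneId = 0; laneId < 32; ++laneId) {
    unsigned grp = laneId / 4;   // mmaGrpId
    unsigned inGrp = laneId % 4; // mmaThreadIdInGrp
    for (unsigned elemId = 0; elemId < 4; ++elemId) {
      // elemId < 2 selects the upper half of the tile; the second pair of
      // elements lands eight rows below (mmaGrpIdP8).
      unsigned row = (elemId < 2) ? grp : grp + 8;
      // Even elements sit at column 2 * inGrp, odd ones one column to the right.
      unsigned col = (elemId % 2 == 0) ? 2 * inGrp : 2 * inGrp + 1;
      std::printf("lane %2u elem %u -> (%2u, %u)\n", laneId, elemId, row, col);
    }
  }
  return 0;
}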
@@ -2517,16 +2548,14 @@ public:
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
  Attribute layout = type.getEncoding();
  if (layout && (layout.isa<BlockedEncodingAttr>() ||
                 layout.isa<SliceEncodingAttr>())) {
  if (layout &&
      (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
       layout.isa<MmaEncodingAttr>())) {
    unsigned numElementsPerThread =
        getElemsPerThread(layout, type.getShape());
    SmallVector<Type, 4> types(numElementsPerThread,
                               convertType(type.getElementType()));
    return LLVM::LLVMStructType::getLiteral(&getContext(), types);
  } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
    // TODO: Not implemented
    return type;
  } else if (auto shared_layout =
                 layout.dyn_cast_or_null<SharedEncodingAttr>()) {
    return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
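Not part of the diff: with mma encodings now accepted here, the converted LLVM type is the same per-thread struct used for blocked and slice layouts, with one field per element returned by getElemsPerThread. A hedged host-side model of that count for the blocked case only (the real helper lives in the TritonGPU dialect and may differ in detail):

#include <cstdint>
#include <cstdio>
#include <vector>

// For a blocked layout, each dimension is covered by ceil(shape / shapePerCTA)
// replicas of the CTA tile, and every replica contributes sizePerThread
// elements per thread along that dimension.
unsigned blockedElemsPerThread(const std::vector<int64_t> &shape,
                               const std::vector<unsigned> &sizePerThread,
                               const std::vector<unsigned> &threadsPerWarp,
                               const std::vector<unsigned> &warpsPerCTA) {
  unsigned elems = 1;
  for (unsigned d = 0; d < shape.size(); ++d) {
    unsigned shapePerCTA =
        sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
    unsigned numCTAs = (shape[d] + shapePerCTA - 1) / shapePerCTA; // ceil
    elems *= numCTAs * sizePerThread[d];
  }
  return elems;
}

int main() {
  // Assumed example: a 128x128 tensor with sizePerThread {1, 4},
  // threadsPerWarp {8, 4}, warpsPerCTA {4, 1} -> 4 * 32 = 128 elements,
  // i.e. an LLVM struct with 128 fields of the element type.
  std::printf("%u\n",
              blockedElemsPerThread({128, 128}, {1, 4}, {8, 4}, {4, 1}));
  return 0;
}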