[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)
@@ -22,8 +22,12 @@ namespace gpu {
 
 unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
 
+SmallVector<unsigned> getSizePerThread(Attribute layout);
+
 unsigned getShapePerCTA(const Attribute &layout, unsigned d);
 
+SmallVector<unsigned> getOrder(const Attribute &layout);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
@@ -11,6 +11,7 @@
 #include <numeric>
 
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
@@ -32,39 +33,40 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
          "Unexpect layout in getScratchConfigForCvtLayout()");
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> paddedRepShape(rank);
-  // TODO: move to TritonGPUAttrDefs.h.inc
-  auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
-    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-      return blockedLayout.getSizePerThread()[d] *
-             blockedLayout.getThreadsPerWarp()[d] *
-             blockedLayout.getWarpsPerCTA()[d];
-    } else {
-      assert(0 && "Unimplemented usage of getShapePerCTA");
-      return 0;
-    }
-  };
-  if (srcLayout.isa<BlockedEncodingAttr>() &&
-      dstLayout.isa<BlockedEncodingAttr>()) {
-    auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
-    auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
-    auto inOrd = srcBlockedLayout.getOrder();
-    auto outOrd = dstBlockedLayout.getOrder();
+  auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
+  auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
+  auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
+  auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
+  assert((srcBlockedLayout || srcMmaLayout) &&
+         "Unexpected srcLayout in getScratchConfigForCvtLayout");
+  assert((dstBlockedLayout || dstMmaLayout) &&
+         "Unexpected dstLayout in getScratchConfigForCvtLayout");
+  assert(!(srcMmaLayout && dstMmaLayout) &&
+         "Unexpected mma -> mma layout conversion");
+  auto inOrd =
+      srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
+  auto outOrd =
+      dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
+  unsigned srcContigPerThread =
+      srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
+  unsigned dstContigPerThread =
+      dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   //       that we cannot do vectorization.
-  inVec = outOrd[0] == 0 ? 1
-          : inOrd[0] == 0 ? 1
-                          : srcBlockedLayout.getSizePerThread()[inOrd[0]];
-  outVec =
-      outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
+  inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
+  outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
   unsigned pad = std::max(inVec, outVec);
   for (unsigned d = 0; d < rank; ++d) {
     paddedRepShape[d] = std::max(
         std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
-        std::min<unsigned>(dstTy.getShape()[d],
-                           getShapePerCTA(dstLayout, d)));
+        std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
   }
-  paddedRepShape[outOrd[0]] += pad;
+  unsigned paddedDim = 1;
+  if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
+    paddedDim = dstBlockedLayout.getOrder()[0];
   }
+  paddedRepShape[paddedDim] += pad;
   return paddedRepShape;
 }
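A quick sanity check on the new scratch-shape arithmetic, replaying the configuration exercised by the lit test added at the end of this commit (tensor<32x16xf32>, #mma with warpsPerCTA = [2, 2], #blocked0 with sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]). The Python sketch below is not part of the commit; it only mirrors the C++ above under those assumptions:

# Test case from the new lit test: tensor<32x16xf32, #mma> -> #blocked0.
shape = [32, 16]
in_ord = out_ord = [1, 0]   # src is mma, so both orders come from the dst layout
src_contig = 2              # mma layout: 2 contiguous elements per thread
dst_contig = 4              # blocked sizePerThread[out_ord[0]] = sizePerThread[1]

in_vec = 1 if out_ord[0] == 0 else (1 if in_ord[0] == 0 else src_contig)  # 2
out_vec = 1 if out_ord[0] == 0 else dst_contig                            # 4
pad = max(in_vec, out_vec)                                                # 4

# Per-CTA tiles: mma v2 covers 16x8 per warp -> [32, 16]; blocked covers
# sizePerThread * threadsPerWarp * warpsPerCTA -> [32, 16]. Both match the
# tensor shape, so paddedRepShape starts out as [32, 16].
padded_rep_shape = [32, 16]
padded_rep_shape[out_ord[0]] += pad       # pad the contiguous dim -> [32, 20]
print(padded_rep_shape, 4 * padded_rep_shape[0] * padded_rep_shape[1])  # 2560 bytes

The resulting inVec = 2 and outVec = 4 are exactly what the test pins down as vector<2xf32> stores and a vector<4xf32> load, and 32 * 20 f32 elements is the 2560-byte @global_smem allocation it checks for.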
@@ -29,7 +29,9 @@ using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
+using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -350,6 +352,13 @@ public:
     return threadId;
   }
 
+  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
+                         int64_t value) const {
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, this->getTypeConverter()->getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
+  }
+
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
                                  ArrayRef<unsigned> shape,
@@ -423,7 +432,7 @@ public:
     auto order = blocked_layout.getOrder();
     unsigned rank = shape.size();
 
-    // step 1, delinearize threadId to get the base index
+    // delinearize threadId to get the base index
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, order);
     SmallVector<Value> multiDimThreadId =
@@ -455,6 +464,13 @@ public:
     return multiDimBase;
   }
 
+  SmallVector<Value>
+  emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
+                                const MmaEncodingAttr &mmaLayout,
+                                ArrayRef<int64_t> shape) const {
+    // ongoing
+  }
+
   SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                               ConversionPatternRewriter &b,
                                               const Attribute &layout,
@@ -1459,9 +1475,11 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
-        (!dstLayout.isa<BlockedEncodingAttr>())) {
-      // TODO: not implemented
+    if ((!srcLayout.isa<BlockedEncodingAttr>() &&
+         !srcLayout.isa<MmaEncodingAttr>()) ||
+        (!dstLayout.isa<BlockedEncodingAttr>() &&
+         !dstLayout.isa<MmaEncodingAttr>())) {
+      // TODO: to be implemented
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1471,31 +1489,6 @@ public:
 
     auto shape = dstTy.getShape();
     unsigned rank = dstTy.getRank();
-    auto getContigPerThread = [&](const Attribute &layout,
-                                  unsigned d) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getSizePerThread()[d];
-      } else {
-        assert(0 && "Unimplemented usage of getContigPerThread");
-        return 0;
-      }
-    };
-    auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return product<unsigned>(blockedLayout.getSizePerThread());
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return 0;
-      }
-    };
-    auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getOrder();
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return {};
-      }
-    };
     SmallVector<unsigned> numReplicates(rank);
     SmallVector<unsigned> inNumCTAsEachRep(rank);
     SmallVector<unsigned> outNumCTAsEachRep(rank);
@@ -1517,7 +1510,6 @@ public:
     }
     // Potentially we need to store for multiple CTAs in this replication
     unsigned accumNumReplicates = product<unsigned>(numReplicates);
-    unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
     unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
     auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
     unsigned inVec = 0;
@@ -1530,19 +1522,21 @@ public:
     for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
       auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
       rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
-                              inNumCTAsEachRep, multiDimRepId, inVec,
-                              paddedRepShape, outOrd, vals, smemBase);
+      if (srcLayout.isa<BlockedEncodingAttr>() ||
+          srcLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                       multiDimRepId, inVec, paddedRepShape, outOrd, vals,
+                       smemBase);
       } else {
         assert(0 && "ConvertLayout with input layout not implemented");
         return failure();
       }
       rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
-                              outNumCTAsEachRep, multiDimRepId, outVec,
-                              paddedRepShape, outOrd, outVals, smemBase);
+      if (dstLayout.isa<BlockedEncodingAttr>() ||
+          dstLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
+                       outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
+                       outOrd, outVals, smemBase);
       } else {
         assert(0 && "ConvertLayout with output layout not implemented");
         return failure();
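The rewritten loop keeps the original round-trip structure: every replication stores the source-layout elements to shared memory and reads them back in the destination layout, with a barrier on each side. A minimal structural sketch (hypothetical Python stand-ins, not commit code):

def barrier():
    pass  # stands in for the mlir::gpu::BarrierOp emitted above

def convert_layout(vals, accum_num_replicates, store_replica, load_replica):
    # Each replication round-trips through the same scratch buffer, so the
    # first barrier orders this rep's stores after the previous rep's loads,
    # and the second makes the stores visible before anyone reads them.
    out_vals = []
    for rep_id in range(accum_num_replicates):
        barrier()
        store_replica(rep_id, vals)       # processReplica(stNotRd=true, srcTy, ...)
        barrier()
        out_vals += load_replica(rep_id)  # processReplica(stNotRd=false, dstTy, ...)
    return out_vals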
@@ -1568,30 +1562,58 @@ private:
     return result;
   };
 
-  void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
-                             bool stNotRd, RankedTensorType type,
-                             ArrayRef<unsigned> numCTAsEachRep,
-                             ArrayRef<unsigned> multiDimRepId, unsigned vec,
-                             ArrayRef<unsigned> paddedRepShape,
-                             ArrayRef<unsigned> outOrd,
-                             SmallVector<Value> &vals, Value smemBase) const {
+  // shared memory access for blocked or mma layout
+  void processReplica(Location loc, ConversionPatternRewriter &rewriter,
+                      bool stNotRd, RankedTensorType type,
+                      ArrayRef<unsigned> numCTAsEachRep,
+                      ArrayRef<unsigned> multiDimRepId, unsigned vec,
+                      ArrayRef<unsigned> paddedRepShape,
+                      ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
+                      Value smemBase) const {
     unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
-    auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
+    auto layout = type.getEncoding();
+    auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
+    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
     auto rank = type.getRank();
-    auto sizePerThread = layout.getSizePerThread();
+    auto sizePerThread = getSizePerThread(layout);
     auto accumSizePerThread = product<unsigned>(sizePerThread);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     SmallVector<unsigned> numCTAs(rank);
     SmallVector<unsigned> shapePerCTA(rank);
     for (unsigned d = 0; d < rank; ++d) {
-      shapePerCTA[d] = layout.getSizePerThread()[d] *
-                       layout.getThreadsPerWarp()[d] *
-                       layout.getWarpsPerCTA()[d];
+      shapePerCTA[d] = getShapePerCTA(layout, d);
       numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
     }
     auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
-    auto multiDimOffsetFirstElem =
-        emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
+    SmallVector<Value> multiDimOffsetFirstElem;
+    Value mmaGrpId;
+    Value mmaGrpIdP8;
+    Value mmaThreadIdInGrpM2;
+    Value mmaThreadIdInGrpM2P1;
+    if (blockedLayout) {
+      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
+          loc, rewriter, blockedLayout, type.getShape());
+    } else if (mmaLayout) {
+      // TODO: simplify these
+      auto cast = rewriter.create<UnrealizedConversionCastOp>(
+          loc, TypeRange{llvmIndexTy},
+          ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
+              loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
+      Value threadId = cast.getResult(0);
+      Value warpSize = createIndexAttrConstant(
+          rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
+      Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
+      Value fourVal = createIndexConst(rewriter, loc, 4);
+      mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
+      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(
+          loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
+      Value mmaThreadIdInGrp =
+          rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
+      mmaThreadIdInGrpM2 = rewriter.create<LLVM::MulOp>(
+          loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
+      mmaThreadIdInGrpM2P1 = rewriter.create<LLVM::AddOp>(
+          loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
+    }
    for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
       auto multiDimCTAInRepId =
           getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
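The four index Values prepared above are a scalar decomposition of the lane id within one mma (v2) m16n8 accumulator tile: lanes are grouped in fours along a row, and each lane owns two adjacent columns in two rows that sit 8 apart. The same arithmetic in plain Python, purely as an illustration:

def mma_v2_lane_indices(lane_id):
    # Mirrors mmaGrpId / mmaGrpIdP8 / mmaThreadIdInGrpM2 / mmaThreadIdInGrpM2P1.
    grp_id = lane_id // 4            # row group within the 16x8 tile (0..7)
    thread_in_grp = lane_id % 4      # position within the group (0..3)
    return grp_id, grp_id + 8, 2 * thread_in_grp, 2 * thread_in_grp + 1

print(mma_v2_lane_indices(5))        # lane 5 owns rows {1, 9} and columns {2, 3}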
@@ -1605,19 +1627,28 @@ private:
       // TODO: This is actually redundant index calculation, we should
       //       consider of caching the index calculation result in case
       //       of performance issue observed.
-      // for (unsigned elemId = linearCTAId * accumSizePerThread;
-      //      elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
       for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-        auto multiDimElemId =
-            getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
         SmallVector<Value> multiDimOffset(rank);
+        if (blockedLayout) {
+          SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+              elemId, blockedLayout.getSizePerThread());
           for (unsigned d = 0; d < rank; ++d) {
-            multiDimOffset[d] = add(
-                multiDimOffsetFirstElem[d],
+            multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
+                loc, multiDimOffsetFirstElem[d],
                 createIndexAttrConstant(rewriter, loc, llvmIndexTy,
                                         multiDimCTAInRepId[d] * shapePerCTA[d] +
                                             multiDimElemId[d]));
           }
+        } else if (mmaLayout) {
+          assert(rank == 2);
+          assert(mmaLayout.getVersion() == 2 &&
+                 "mmaLayout ver1 not implemented yet");
+          multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
+          multiDimOffset[1] =
+              elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
+        } else {
+          assert(0 && "unexpected layout in processReplica");
+        }
         Value offset =
             linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
                       reorder<unsigned>(paddedRepShape, outOrd));
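Combined with the lane decomposition, the elemId-based selection places a lane's four accumulator values at (grpId, 2t), (grpId, 2t+1), (grpId+8, 2t) and (grpId+8, 2t+1). The check below (an illustrative model only, not commit code) confirms that a warp's 32 lanes tile a 16x8 fragment exactly once, and sketches the final linearize/reorder step for order [1, 0], assuming the reordered leading dimension is the fastest-varying one:

covered = set()
for lane in range(32):
    grp, t = lane // 4, lane % 4
    for elem_id in range(4):
        row = grp if elem_id < 2 else grp + 8
        col = 2 * t if elem_id % 2 == 0 else 2 * t + 1
        covered.add((row, col))
assert covered == {(r, c) for r in range(16) for c in range(8)}  # 128 slots, no overlap

# linearize(reorder(offset, [1, 0]), reorder(paddedRepShape, [1, 0])):
# with paddedRepShape = [32, 20] this reduces to col + row * 20.
padded_rep_shape = [32, 20]  # from the scratch-config example above
print(3 + 9 * padded_rep_shape[1])   # element (row 9, col 3) -> smem slot 183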
@@ -2517,16 +2548,14 @@ public:
 
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     Attribute layout = type.getEncoding();
-    if (layout && (layout.isa<BlockedEncodingAttr>() ||
-                   layout.isa<SliceEncodingAttr>())) {
+    if (layout &&
+        (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
+         layout.isa<MmaEncodingAttr>())) {
       unsigned numElementsPerThread =
           getElemsPerThread(layout, type.getShape());
       SmallVector<Type, 4> types(numElementsPerThread,
                                  convertType(type.getElementType()));
       return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-    } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
-      // TODO: Not implemented
-      return type;
     } else if (auto shared_layout =
                    layout.dyn_cast_or_null<SharedEncodingAttr>()) {
       return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
@@ -58,17 +58,56 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   }
 }
 
+SmallVector<unsigned> getSizePerThread(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
+                                 blockedLayout.getSizePerThread().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    return SmallVector<unsigned>{2, 2};
+  } else {
+    assert(0 && "getSizePerThread not implemented");
+    return {};
+  }
+}
+
 unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getSizePerThread()[d] *
            blockedLayout.getThreadsPerWarp()[d] *
            blockedLayout.getWarpsPerCTA()[d];
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    assert(d < 2 && "Unexpected usage of getShapePerCTA");
+    if (d == 0) {
+      return 16 * mmaLayout.getWarpsPerCTA()[0];
+    } else {
+      // d == 1
+      return 8 * mmaLayout.getWarpsPerCTA()[1];
+    }
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
     return 0;
   }
 };
 
+SmallVector<unsigned> getOrder(const Attribute &layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
+                                 blockedLayout.getOrder().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>{1, 0};
+  } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
+    return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
+                                 sharedLayout.getOrder().end());
+  } else {
+    assert(0 && "Unimplemented usage of getOrder");
+    return {};
+  }
+};
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
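These shared helpers are what let getScratchConfigForCvtLayout and processReplica stay layout-agnostic. A scalar model of the dispatch (hypothetical Python, mirroring the cases above) shows that the mma and blocked layouts from the new lit test happen to produce the same 32x16 CTA tile, so the conversion needs only a single replication:

def shape_per_cta(layout, d):
    if layout["kind"] == "blocked":
        return (layout["sizePerThread"][d] * layout["threadsPerWarp"][d] *
                layout["warpsPerCTA"][d])
    if layout["kind"] == "mma":  # v2: one 16x8 tile per warp
        return (16 if d == 0 else 8) * layout["warpsPerCTA"][d]
    raise NotImplementedError

mma = {"kind": "mma", "warpsPerCTA": [2, 2]}
blocked = {"kind": "blocked", "sizePerThread": [1, 4],
           "threadsPerWarp": [32, 1], "warpsPerCTA": [1, 4]}
print([shape_per_cta(mma, d) for d in range(2)])      # [32, 16]
print([shape_per_cta(blocked, d) for d in range(2)])  # [32, 16]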
@@ -177,9 +216,11 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  int threads = product(getWarpsPerCTA());
-  int numElem = product(shape);
-  return numElem / threads;
+  size_t rank = shape.size();
+  assert(rank == 2 && "Unexpected rank of mma layout");
+  unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
+  unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
+  return elemsCol * elemsRow;
 }
 
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
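The old placeholder divided the element count by the number of warps; the new formula instead counts how many CTA tiles are needed to cover the shape, with each thread holding a 2x2 fragment per tile. A quick model of the arithmetic (illustrative only):

from math import ceil

def mma_v2_elems_per_thread(shape, warps_per_cta):
    # Mirrors elemsCol / elemsRow above: replicate the CTA tile to cover the
    # shape, 2 values per thread along each dimension of every replica.
    elems_col = ceil(shape[0] / (16 * warps_per_cta[0])) * 2
    elems_row = ceil(shape[1] / (8 * warps_per_cta[1])) * 2
    return elems_col * elems_row

print(mma_v2_elems_per_thread([32, 16], [2, 2]))   # 4
print(mma_v2_elems_per_thread([128, 64], [2, 2]))  # 64: 4x4 tile replicas, 4 values each

The value 4 for the 32x16 test tensor is also why convertTritonTensorType above now lowers tensor<32x16xf32, #mma> to a four-field LLVM struct.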
@@ -1127,6 +1127,10 @@ def default_cache_dir():
     return os.path.join(os.environ["HOME"], ".triton", "cache")
 
 
+def default_cuda_dir():
+    return os.path.join("/usr", "local", "cuda")
+
+
 class CacheManager:
 
     def __init__(self, key):
@@ -1181,7 +1185,8 @@ def quiet():
 
 def _build(name, src, srcdir):
     cuda_lib_dir = libcuda_dir()
-    cu_include_dir = "/usr/local/cuda/include"
+    cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
+    cu_include_dir = os.path.join(cuda_path, "include")
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
     # try to avoid setuptools if possible
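On the Python side, the include path now honors a CUDA_PATH environment variable and only falls back to /usr/local/cuda, so non-default CUDA installs no longer need a symlink. Illustrative usage (the install prefix below is hypothetical):

import os

def default_cuda_dir():
    return os.path.join("/usr", "local", "cuda")

os.environ["CUDA_PATH"] = "/opt/cuda-11.8"                   # hypothetical install
cuda_path = os.environ.get("CUDA_PATH", default_cuda_dir())
print(os.path.join(cuda_path, "include"))                    # /opt/cuda-11.8/include

del os.environ["CUDA_PATH"]
cuda_path = os.environ.get("CUDA_PATH", default_cuda_dir())
print(os.path.join(cuda_path, "include"))                    # /usr/local/cuda/include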
@@ -486,7 +486,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
 
-
 // TODO: problems in MLIR's parser on slice layout
 // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 // module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -495,3 +494,24 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 //     return
 //   }
 // }
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
+  // CHECK-LABEL: convert_layout_mma_block
+  func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
+    return
+  }
+}