[Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693)

Author: goostavz
Committed by: GitHub
Date: 2022-09-27 11:58:47 +08:00
Parent: 1e91ed30d0
Commit: 61b61755e5
6 changed files with 205 additions and 104 deletions

View File

@@ -22,8 +22,12 @@ namespace gpu {
 unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
+SmallVector<unsigned> getSizePerThread(Attribute layout);
 unsigned getShapePerCTA(const Attribute &layout, unsigned d);
+SmallVector<unsigned> getOrder(const Attribute &layout);
 } // namespace gpu
 } // namespace triton
 } // namespace mlir

View File

@@ -11,6 +11,7 @@
 #include <numeric>
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -32,39 +33,40 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
"Unexpect layout in getScratchConfigForCvtLayout()"); "Unexpect layout in getScratchConfigForCvtLayout()");
unsigned rank = dstTy.getRank(); unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank); SmallVector<unsigned> paddedRepShape(rank);
// TODO: move to TritonGPUAttrDefs.h.inc auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned { auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) { auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
return blockedLayout.getSizePerThread()[d] * auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
blockedLayout.getThreadsPerWarp()[d] * assert((srcBlockedLayout || srcMmaLayout) &&
blockedLayout.getWarpsPerCTA()[d]; "Unexpected srcLayout in getScratchConfigForCvtLayout");
} else { assert((dstBlockedLayout || dstMmaLayout) &&
assert(0 && "Unimplemented usage of getShapePerCTA"); "Unexpected dstLayout in getScratchConfigForCvtLayout");
return 0; assert(!(srcMmaLayout && dstMmaLayout) &&
} "Unexpected mma -> mma layout conversion");
}; auto inOrd =
if (srcLayout.isa<BlockedEncodingAttr>() && srcMmaLayout ? dstBlockedLayout.getOrder() : srcBlockedLayout.getOrder();
dstLayout.isa<BlockedEncodingAttr>()) { auto outOrd =
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>(); dstMmaLayout ? srcBlockedLayout.getOrder() : dstBlockedLayout.getOrder();
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>(); unsigned srcContigPerThread =
auto inOrd = srcBlockedLayout.getOrder(); srcBlockedLayout ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 2;
auto outOrd = dstBlockedLayout.getOrder(); unsigned dstContigPerThread =
dstBlockedLayout ? dstBlockedLayout.getSizePerThread()[outOrd[0]] : 2;
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization. // that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
: inOrd[0] == 0 ? 1 outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
outVec =
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
unsigned pad = std::max(inVec, outVec); unsigned pad = std::max(inVec, outVec);
for (unsigned d = 0; d < rank; ++d) { for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] = std::max( paddedRepShape[d] = std::max(
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)), std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
std::min<unsigned>(dstTy.getShape()[d], std::min<unsigned>(dstTy.getShape()[d], getShapePerCTA(dstLayout, d)));
getShapePerCTA(dstLayout, d)));
} }
paddedRepShape[outOrd[0]] += pad; unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
} }
paddedRepShape[paddedDim] += pad;
return paddedRepShape; return paddedRepShape;
} }
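To make the new vectorization rule concrete, here is a small standalone sketch (my own illustration, not code from this commit) evaluating inVec, outVec, and the padding for an mma -> blocked conversion in which both orders are [1, 0] and the blocked destination uses sizePerThread = [1, 4], the configuration exercised by the lit test added at the end of this commit:

    #include <algorithm>
    #include <cstdio>

    int main() {
      // mma -> blocked, both orders [1, 0]: the fastest-varying dim is 1.
      unsigned inOrd0 = 1, outOrd0 = 1;
      unsigned srcContigPerThread = 2; // mma v2 source is hard-coded to 2
      unsigned dstContigPerThread = 4; // blocked dst: sizePerThread[outOrd[0]]
      unsigned inVec = outOrd0 == 0 ? 1 : inOrd0 == 0 ? 1 : srcContigPerThread;
      unsigned outVec = outOrd0 == 0 ? 1 : dstContigPerThread;
      unsigned pad = std::max(inVec, outVec);
      // Prints inVec=2 outVec=4 pad=4: stores are 2-wide, loads are 4-wide,
      // and the padded dimension of the scratch buffer grows by 4 elements.
      std::printf("inVec=%u outVec=%u pad=%u\n", inVec, outVec, pad);
      return 0;
    }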

View File

@@ -29,7 +29,9 @@ using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
+using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -350,6 +352,13 @@ public:
     return threadId;
   }
+  Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
+                         int64_t value) const {
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, this->getTypeConverter()->getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(), value));
+  }
   SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
                                  Location loc, Value linear,
                                  ArrayRef<unsigned> shape,
@@ -423,7 +432,7 @@ public:
     auto order = blocked_layout.getOrder();
     unsigned rank = shape.size();
-    // step 1, delinearize threadId to get the base index
+    // delinearize threadId to get the base index
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, order);
     SmallVector<Value> multiDimThreadId =
@@ -455,6 +464,13 @@ public:
     return multiDimBase;
   }
+  SmallVector<Value>
+  emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b,
+                                const MmaEncodingAttr &mmaLayout,
+                                ArrayRef<int64_t> shape) const {
+    // ongoing
+  }
   SmallVector<SmallVector<Value>> emitIndices(Location loc,
                                               ConversionPatternRewriter &b,
                                               const Attribute &layout,
@@ -1459,9 +1475,11 @@ public:
     auto dstTy = dst.getType().cast<RankedTensorType>();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
-        (!dstLayout.isa<BlockedEncodingAttr>())) {
-      // TODO: not implemented
+    if ((!srcLayout.isa<BlockedEncodingAttr>() &&
+         !srcLayout.isa<MmaEncodingAttr>()) ||
+        (!dstLayout.isa<BlockedEncodingAttr>() &&
+         !dstLayout.isa<MmaEncodingAttr>())) {
+      // TODO: to be implemented
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
@@ -1471,31 +1489,6 @@ public:
     auto shape = dstTy.getShape();
     unsigned rank = dstTy.getRank();
-    auto getContigPerThread = [&](const Attribute &layout,
-                                  unsigned d) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getSizePerThread()[d];
-      } else {
-        assert(0 && "Unimplemented usage of getContigPerThread");
-        return 0;
-      }
-    };
-    auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return product<unsigned>(blockedLayout.getSizePerThread());
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return 0;
-      }
-    };
-    auto getOrder = [&](const Attribute &layout) -> ArrayRef<unsigned> {
-      if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
-        return blockedLayout.getOrder();
-      } else {
-        assert(0 && "Unimplemented usage of getAccumElemsPerThread");
-        return {};
-      }
-    };
     SmallVector<unsigned> numReplicates(rank);
     SmallVector<unsigned> inNumCTAsEachRep(rank);
     SmallVector<unsigned> outNumCTAsEachRep(rank);
@@ -1517,7 +1510,6 @@ public:
     }
     // Potentially we need to store for multiple CTAs in this replication
     unsigned accumNumReplicates = product<unsigned>(numReplicates);
-    unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout);
     unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
     auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
     unsigned inVec = 0;
@@ -1530,19 +1522,21 @@ public:
     for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
       auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
       rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy,
-                              inNumCTAsEachRep, multiDimRepId, inVec,
-                              paddedRepShape, outOrd, vals, smemBase);
+      if (srcLayout.isa<BlockedEncodingAttr>() ||
+          srcLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                       multiDimRepId, inVec, paddedRepShape, outOrd, vals,
+                       smemBase);
       } else {
         assert(0 && "ConvertLayout with input layout not implemented");
         return failure();
       }
       rewriter.create<mlir::gpu::BarrierOp>(loc);
-      if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
-        processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy,
-                              outNumCTAsEachRep, multiDimRepId, outVec,
-                              paddedRepShape, outOrd, outVals, smemBase);
+      if (dstLayout.isa<BlockedEncodingAttr>() ||
+          dstLayout.isa<MmaEncodingAttr>()) {
+        processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
+                       outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
+                       outOrd, outVals, smemBase);
       } else {
         assert(0 && "ConvertLayout with output layout not implemented");
         return failure();
@@ -1568,30 +1562,58 @@ private:
     return result;
   };
-  void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter,
-                             bool stNotRd, RankedTensorType type,
-                             ArrayRef<unsigned> numCTAsEachRep,
-                             ArrayRef<unsigned> multiDimRepId, unsigned vec,
-                             ArrayRef<unsigned> paddedRepShape,
-                             ArrayRef<unsigned> outOrd,
-                             SmallVector<Value> &vals, Value smemBase) const {
+  // shared memory access for blocked or mma layout
+  void processReplica(Location loc, ConversionPatternRewriter &rewriter,
+                      bool stNotRd, RankedTensorType type,
+                      ArrayRef<unsigned> numCTAsEachRep,
+                      ArrayRef<unsigned> multiDimRepId, unsigned vec,
+                      ArrayRef<unsigned> paddedRepShape,
+                      ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
+                      Value smemBase) const {
     unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
-    auto layout = type.getEncoding().cast<BlockedEncodingAttr>();
+    auto layout = type.getEncoding();
+    auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
+    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
     auto rank = type.getRank();
-    auto sizePerThread = layout.getSizePerThread();
+    auto sizePerThread = getSizePerThread(layout);
     auto accumSizePerThread = product<unsigned>(sizePerThread);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     SmallVector<unsigned> numCTAs(rank);
     SmallVector<unsigned> shapePerCTA(rank);
     for (unsigned d = 0; d < rank; ++d) {
-      shapePerCTA[d] = layout.getSizePerThread()[d] *
-                       layout.getThreadsPerWarp()[d] *
-                       layout.getWarpsPerCTA()[d];
+      shapePerCTA[d] = getShapePerCTA(layout, d);
       numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
     }
     auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
-    auto multiDimOffsetFirstElem =
-        emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape());
+    SmallVector<Value> multiDimOffsetFirstElem;
+    Value mmaGrpId;
+    Value mmaGrpIdP8;
+    Value mmaThreadIdInGrpM2;
+    Value mmaThreadIdInGrpM2P1;
+    if (blockedLayout) {
+      multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
+          loc, rewriter, blockedLayout, type.getShape());
+    } else if (mmaLayout) {
+      // TODO: simplify these
+      auto cast = rewriter.create<UnrealizedConversionCastOp>(
+          loc, TypeRange{llvmIndexTy},
+          ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
+              loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
+      Value threadId = cast.getResult(0);
+      Value warpSize = createIndexAttrConstant(
+          rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
+      Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
+      Value fourVal = createIndexConst(rewriter, loc, 4);
+      mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
+      mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(
+          loc, mmaGrpId, createIndexConst(rewriter, loc, 8));
+      Value mmaThreadIdInGrp =
+          rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
+      mmaThreadIdInGrpM2 = rewriter.create<LLVM::MulOp>(
+          loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2));
+      mmaThreadIdInGrpM2P1 = rewriter.create<LLVM::AddOp>(
+          loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1));
+    }
     for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
       auto multiDimCTAInRepId =
           getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
@@ -1605,19 +1627,28 @@ private:
         // TODO: This is actually redundant index calculation, we should
         // consider of caching the index calculation result in case
         // of performance issue observed.
-        // for (unsigned elemId = linearCTAId * accumSizePerThread;
-        //      elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) {
         for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
-          auto multiDimElemId =
-              getMultiDimIndex<unsigned>(elemId, layout.getSizePerThread());
           SmallVector<Value> multiDimOffset(rank);
-          for (unsigned d = 0; d < rank; ++d) {
-            multiDimOffset[d] = add(
-                multiDimOffsetFirstElem[d],
-                createIndexAttrConstant(rewriter, loc, llvmIndexTy,
-                                        multiDimCTAInRepId[d] * shapePerCTA[d] +
-                                            multiDimElemId[d]));
-          }
+          if (blockedLayout) {
+            SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+                elemId, blockedLayout.getSizePerThread());
+            for (unsigned d = 0; d < rank; ++d) {
+              multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
+                  loc, multiDimOffsetFirstElem[d],
+                  createIndexAttrConstant(rewriter, loc, llvmIndexTy,
+                                          multiDimCTAInRepId[d] * shapePerCTA[d] +
+                                              multiDimElemId[d]));
+            }
+          } else if (mmaLayout) {
+            assert(rank == 2);
+            assert(mmaLayout.getVersion() == 2 &&
+                   "mmaLayout ver1 not implemented yet");
+            multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
+            multiDimOffset[1] =
+                elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
+          } else {
+            assert(0 && "unexpected layout in processReplica");
+          }
           Value offset =
               linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
                         reorder<unsigned>(paddedRepShape, outOrd));
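The mmaGrpId / mmaThreadIdInGrp offsets emitted above follow the mma version 2 (Ampere, 16x8) accumulator placement: each lane of the warp owns a 2x2 footprint, with its two rows 8 apart and its two columns adjacent. A rough scalar model of that mapping (my reading of the code, not part of this commit):

    #include <cstdio>

    int main() {
      // Elements 0/1 of a lane sit on row laneId/4, elements 2/3 on row
      // laneId/4 + 8; columns are 2*(laneId%4) and 2*(laneId%4) + 1,
      // matching the multiDimOffset[0]/[1] selection in processReplica.
      for (unsigned laneId = 0; laneId < 32; ++laneId) {
        unsigned grp = laneId / 4; // mmaGrpId
        unsigned tid = laneId % 4; // mmaThreadIdInGrp
        std::printf("lane %2u: rows {%u, %u}, cols {%u, %u}\n", laneId, grp,
                    grp + 8, 2 * tid, 2 * tid + 1);
      }
      return 0;
    }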
@@ -2517,16 +2548,14 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     Attribute layout = type.getEncoding();
-    if (layout && (layout.isa<BlockedEncodingAttr>() ||
-                   layout.isa<SliceEncodingAttr>())) {
+    if (layout &&
+        (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
+         layout.isa<MmaEncodingAttr>())) {
       unsigned numElementsPerThread =
           getElemsPerThread(layout, type.getShape());
       SmallVector<Type, 4> types(numElementsPerThread,
                                  convertType(type.getElementType()));
       return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-    } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
-      // TODO: Not implemented
-      return type;
     } else if (auto shared_layout =
                    layout.dyn_cast_or_null<SharedEncodingAttr>()) {
       return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);

View File

@@ -58,17 +58,56 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
   }
 }
+SmallVector<unsigned> getSizePerThread(Attribute layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
+                                 blockedLayout.getSizePerThread().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    return SmallVector<unsigned>{2, 2};
+  } else {
+    assert(0 && "getSizePerThread not implemented");
+    return {};
+  }
+}
 unsigned getShapePerCTA(const Attribute &layout, unsigned d) {
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getSizePerThread()[d] *
            blockedLayout.getThreadsPerWarp()[d] *
            blockedLayout.getWarpsPerCTA()[d];
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    assert(mmaLayout.getVersion() == 2 &&
+           "mmaLayout version = 1 is not implemented yet");
+    assert(d < 2 && "Unexpected usage of getShapePerCTA");
+    if (d == 0) {
+      return 16 * mmaLayout.getWarpsPerCTA()[0];
+    } else {
+      // d == 1
+      return 8 * mmaLayout.getWarpsPerCTA()[1];
+    }
   } else {
     assert(0 && "Unimplemented usage of getShapePerCTA");
     return 0;
   }
 };
+SmallVector<unsigned> getOrder(const Attribute &layout) {
+  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
+    return SmallVector<unsigned>(blockedLayout.getOrder().begin(),
+                                 blockedLayout.getOrder().end());
+  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
+    return SmallVector<unsigned>{1, 0};
+  } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
+    return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
+                                 sharedLayout.getOrder().end());
+  } else {
+    assert(0 && "Unimplemented usage of getOrder");
+    return {};
+  }
+};
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
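Note (my worked example, not text from this commit): for the #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}> layout used by the new lit test, these helpers return sizePerThread = {2, 2}, shapePerCTA = {16 * 2, 8 * 2} = {32, 16}, and order = {1, 0}.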
@@ -177,9 +216,11 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 }
 unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-  int threads = product(getWarpsPerCTA());
-  int numElem = product(shape);
-  return numElem / threads;
+  size_t rank = shape.size();
+  assert(rank == 2 && "Unexpected rank of mma layout");
+  unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
+  unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
+  return elemsCol * elemsRow;
 }
 unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
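Note (my arithmetic, not text from this commit): with the same #mma layout on a 32x16 tensor, elemsCol = ceil(32, 32) * 2 = 2 and elemsRow = ceil(16, 16) * 2 = 2, so each thread holds 4 accumulator elements; this is why the converted tensor lowers to a 4-element LLVM struct and the lit test below sees two vector<2xf32> stores per thread.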

View File

@@ -1127,6 +1127,10 @@ def default_cache_dir():
     return os.path.join(os.environ["HOME"], ".triton", "cache")
+
+def default_cuda_dir():
+    return os.path.join("/usr", "local", "cuda")
+
 class CacheManager:
     def __init__(self, key):
@@ -1181,7 +1185,8 @@ def quiet():
 def _build(name, src, srcdir):
     cuda_lib_dir = libcuda_dir()
-    cu_include_dir = "/usr/local/cuda/include"
+    cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
+    cu_include_dir = os.path.join(cuda_path, "include")
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
     # try to avoid setuptools if possible

View File

@@ -486,7 +486,6 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   }
 }
-// TODO: problems in MLIR's parser on slice layout
 // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 // module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -495,3 +494,24 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 // return
 // }
 // }
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+module attributes {"triton_gpu.num-warps" = 1 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
+  // CHECK-LABEL: convert_layout_mma_block
+  func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: llvm.store
+    // CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load
+    // CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
+    %0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
+    return
+  }
+}
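Note (my arithmetic, not text from this commit): the 2560-byte @global_smem figure in the CHECK line above is consistent with the new scratch-config code. Both #mma (16*2 x 8*2) and #blocked0 (1*32*1 x 4*1*4) cover a 32x16 tile per CTA, so the replica shape is [32, 16]; the padded dimension is the blocked layout's fastest dimension (dim 1) and the pad is max(inVec, outVec) = max(2, 4) = 4, giving [32, 20]; 32 * 20 f32 elements * 4 bytes = 2560 bytes of shared memory.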