[Triton-MLIR][Backend] Support ConvertLayout blocked->shared and a few fixes related to mma (#716)

This commit is contained in:
goostavz
2022-10-03 19:33:25 +08:00
committed by GitHub
parent baba98ad69
commit f9d7f2f126
5 changed files with 458 additions and 227 deletions


@@ -67,8 +67,6 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}
// Add other specification if needed...
} // namespace
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
@@ -89,8 +87,12 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define vec_ty(type, num) VectorType::get(num, type)
#define void_ty LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)
// Creator for constant
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
@@ -109,6 +111,18 @@ namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO: keep these until we have better debug log utilities
template <typename T>
void print_array(ArrayRef<T> array, const std::string &str) {
std::cout << str << ": ";
for (const T &e : array)
std::cout << e << ",";
std::cout << std::endl;
}
template <typename T> void print_scalar(const T &e, const std::string &str) {
std::cout << str << ": " << e << std::endl;
}
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in header files in mlir v14
@@ -360,6 +374,13 @@ public:
return threadId;
}
Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
int64_t value) const {
return rewriter.create<LLVM::ConstantOp>(
loc, this->getTypeConverter()->getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), value));
}
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
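For reference, a minimal Python sketch of what the delinearize/linearize helpers do conceptually (the helper names and the assumption that order[0] is the fastest-varying dimension are illustrative, not taken verbatim from this commit):

def delinearize(linear, shape, order):
    # Decompose a linear index into a multi-dim coordinate,
    # peeling off the fastest-varying dimension (order[0]) first.
    multi_dim = [0] * len(shape)
    rest = linear
    for d in order:
        multi_dim[d] = rest % shape[d]
        rest //= shape[d]
    return multi_dim

def linearize(multi_dim, shape, order):
    # Inverse of delinearize under the same dimension order.
    linear = 0
    for d in reversed(order):
        linear = linear * shape[d] + multi_dim[d]
    return linear

# round-trip check: shape 8x16, dimension 1 contiguous
assert linearize(delinearize(37, [8, 16], [1, 0]), [8, 16], [1, 0]) == 37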
@@ -859,7 +880,7 @@ struct StoreOpConversion
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
}
// TODO(Superjomn) Need to check masks before vectorize the load for all
// TODO(Superjomn) Check whether all the values share one mask predicate
// before vectorizing the load. Here we assume all the mask values are
// the same.
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
@@ -1060,7 +1081,6 @@ struct LoadOpConversion
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
@@ -1413,87 +1433,24 @@ public:
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerBlockedToShared(op, adaptor, rewriter);
}
if ((!srcLayout.isa<BlockedEncodingAttr>() &&
!srcLayout.isa<MmaEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>() &&
!dstLayout.isa<MmaEncodingAttr>())) {
// TODO: to be implemented
llvm::errs() << "Unsupported ConvertLayout found";
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
// TODO: confirm this
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
private:
@@ -1508,122 +1465,334 @@ private:
return result;
};
// shared memory access for blocked or mma layout
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
SmallVector<Value> multiDimOffsetFirstElem;
Value mmaGrpId;
Value mmaGrpIdP8;
Value mmaThreadIdInGrpM2;
Value mmaThreadIdInGrpM2P1;
if (blockedLayout) {
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (mmaLayout) {
// TODO: simplify these
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
Value warpSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
Value fourVal = idx_val(4);
mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(loc, mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp =
rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
mmaThreadIdInGrpM2 =
rewriter.create<LLVM::MulOp>(loc, mmaThreadIdInGrp, idx_val(2));
mmaThreadIdInGrpM2P1 =
rewriter.create<LLVM::AddOp>(loc, mmaThreadIdInGrpM2, idx_val(1));
}
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
Value smemBase) const;
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset(rank);
if (blockedLayout) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
loc, multiDimOffsetFirstElem[d],
createIndexAttrConstant(rewriter, loc, llvmIndexTy,
multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
} else if (mmaLayout) {
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
multiDimOffset[1] =
elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
} else {
assert(0 && "unexpected layout in processReplica");
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
};
void ConvertLayoutOpConversion::processReplica(
Location loc, ConversionPatternRewriter &rewriter, bool stNotRd,
RankedTensorType type, ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
SmallVector<Value> multiDimOffsetFirstElem;
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
if (blockedLayout) {
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (mmaLayout) {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// auto multiDimWarpId =
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
// TODO: double-check whether this is a documentation bug or a bug in DotConversion
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value four = idx_val(4);
Value mmaGrpId = udiv(laneId, four);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp = urem(laneId, four);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
}
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is redundant index calculation; we should consider
// caching the result in case a performance issue is observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset(rank);
if (blockedLayout) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] =
add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, vVal);
}
} else if (mmaLayout) {
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
assert(0 && "unexpected layout in processReplica");
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, idx_val(v));
}
}
}
}
}
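As a reading aid, a small Python sketch (illustrative only; the helper name and return format are made up) of how the mmaLayout branch above maps a thread id to the per-replica offsets for MMA version 2:

def mma_v2_offsets(thread_id, warps_per_cta_0):
    # Mirrors the lane/warp decomposition in processReplica.
    lane, warp = thread_id % 32, thread_id // 32
    warp_m = warp % warps_per_cta_0      # multiDimWarpId[0]
    warp_n = warp // warps_per_cta_0     # multiDimWarpId[1]
    grp, tid_in_grp = lane // 4, lane % 4
    dim0 = [grp + 16 * warp_m, grp + 8 + 16 * warp_m]        # mmaColIdx[0], mmaColIdx[1]
    dim1 = [2 * tid_in_grp + 8 * warp_n,
            2 * tid_in_grp + 1 + 8 * warp_n]                 # mmaRowIdx[0], mmaRowIdx[1]
    # elemId < 2 selects within dim0, elemId % 2 selects within dim1.
    return [(dim0[0 if e < 2 else 1], dim1[e % 2]) for e in range(4)]

# lane 0 of warp 0 holds its four accumulator elements at (0,0), (0,1), (8,0), (8,1)
assert mma_v2_offsets(0, 1) == [(0, 0), (0, 1), (8, 0), (8, 1)]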
LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
barrier;
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier;
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = struct_ty(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
};
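To make the replication bookkeeping in lowerDistributedToDistributed easier to follow, here is a small Python sketch of the per-dimension computation of numReplicates / inNumCTAsEachRep / outNumCTAsEachRep (the example shapes are hypothetical, not taken from the commit):

def cvt_replication(shape, src_per_cta, dst_per_cta):
    ceil = lambda a, b: -(-a // b)
    num_reps, in_per_rep, out_per_rep = [], [], []
    for s, src, dst in zip(shape, src_per_cta, dst_per_cta):
        in_cta = min(s, src)
        out_cta = min(s, dst)
        max_cta = max(in_cta, out_cta)
        num_reps.append(ceil(s, max_cta))
        in_per_rep.append(max_cta // in_cta)     # source CTA tiles covered per replica
        out_per_rep.append(max_cta // out_cta)   # destination CTA tiles covered per replica
    return num_reps, in_per_rep, out_per_rep

# a 128x32 tensor, source covering 64x32 per CTA tile, destination covering 128x16
assert cvt_replication([128, 32], [64, 32], [128, 16]) == ([1, 1], [2, 1], [1, 2])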
LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstSharedLayout.getOrder();
unsigned inVec =
inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter);
unsigned srcAccumSizeInThreads =
product<unsigned>(srcBlockedLayout.getSizePerThread());
auto elemTy = srcTy.getElementType();
auto wordTy = vec_ty(elemTy, minVec);
// TODO: [goostavz] We should cache the result of
// emitBaseIndexForBlockedLayout in case the backend compiler is unable to
// optimize it
SmallVector<Value> multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};
// Visit each input value in the order they are placed in inVals
//
// Note that this order is not aware of blockedLayout.getOrder(), so
// adjacent elements may not belong to the same word. This could be
// improved by updating the element order via emitIndicesForBlockedLayout()
SmallVector<unsigned> wordsInEachRep(2);
wordsInEachRep[0] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[0] / minVec
: srcBlockedLayout.getSizePerThread()[0];
wordsInEachRep[1] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[1]
: srcBlockedLayout.getSizePerThread()[1] / minVec;
Value outVecVal = idx_val(outVec);
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bit_cast(elemPtrTy, smemBase);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {
if (i % srcAccumSizeInThreads == 0) {
// start of a replication
for (unsigned w = 0; w < numWordsEachRep; ++w) {
wordVecs[w] = undef(wordTy);
}
}
unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
unsigned wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
wordVecs[wordVecIdx] =
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
// end of replication, store the vectors into shared memory
unsigned linearRepIdx = i / srcAccumSizeInThreads;
auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
++linearWordIdx) {
// step 1: recover the multidim_index from the index of input_elements
auto multiDimWordIdx =
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
SmallVector<Value> multiDimIdx(2);
auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
multiDimIdx[0] = add(multiDimOffsetFirstElem[0], idx_val(wordOffset0));
multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));
// step 2: do swizzling
Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
phaseId = urem(phaseId, idx_val(maxPhase));
Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));
Value offset = add(off_1, off_0);
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
store(wordVecs[linearWordIdx], smemAddr);
}
}
}
// TODO: double-check whether the barrier is necessary here
barrier;
rewriter.replaceOp(op, smemBase);
return success();
}
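For readers unfamiliar with the swizzling scheme, a minimal Python sketch of the per-word shared-memory offset computed in step 2 above (argument names are illustrative): rows that would otherwise hit the same banks are xor-ed with different phases, spreading their accesses across banks.

def swizzled_offset(idx_fast, idx_slow, shape_fast, out_vec, min_vec, per_phase, max_phase):
    # idx_fast: coordinate along the contiguous dimension inOrd[0]
    # idx_slow: coordinate along inOrd[1]; shape_fast = srcShape[inOrd[0]]
    remained = idx_fast % out_vec
    vec_idx = idx_fast // out_vec
    off_1 = idx_slow * shape_fast                    # row base
    phase = (idx_slow // per_phase) % max_phase      # phaseId
    off_0 = (vec_idx ^ phase) * out_vec              # swizzled column of the vector
    off_0 += (remained // min_vec) * min_vec         # position of the word inside the vector
    return off_1 + off_0

# e.g. vec = 8, perPhase = 2, maxPhase = 4, as in the #shared0 layout of the new lit test
print(swizzled_offset(8, 3, 32, 8, 8, 2, 4))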
/// ====================== dot codegen begin ==========================
// Data loader for mma.16816 instruction.
@@ -1843,7 +2012,7 @@ public:
if (canUseLdmatrix)
ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
else if (elemBytes == 4 && needTrans) // tf32 & trans
else if (elemBytes == 4 && needTrans)
ptrIdx = matIdx[order[0]];
else if (elemBytes == 1 && needTrans)
ptrIdx = matIdx[order[0]] * 4;
@@ -2127,10 +2296,6 @@ struct DotOpConversionHelper {
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
ATensorTy = A.getType().cast<RankedTensorType>();
BTensorTy = B.getType().cast<RankedTensorType>();
DTensorTy = D.getType().cast<RankedTensorType>();
}
// Load SplatLike C which contains a constVal. It simply returns 4 fp32
@@ -2469,7 +2634,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
bool needTrans = kOrder != order[0];
// (a, b) is the coordinate.
auto load = [&, loader, ptrs, offs, needTrans](int a, int b) {
auto load = [=, &vals, &helper, &ld2](int a, int b) {
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, helper.getMatType(), helper.getShemPtrTy());
@@ -2490,78 +2655,68 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
};
std::function<void(int, int)> loadA;
std::function<void(int, int)> loadB = getLoadMatrixFn(
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
if (aTensorTy.getEncoding()
.dyn_cast<SharedEncodingAttr>()) { // load from smem
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
// load from smem
loadA = getLoadMatrixFn(
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (auto blockedLayout =
aTensorTy.getEncoding()
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,
// used in gemm fuse
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
// TODO(Superjomn) Port the logic.
assert(false && "Loading A from register is not supported yet.");
} else {
assert(false && "A's layout is not supported.");
}
const unsigned mStride = numRepN * 2;
SmallVector<Value> fc(numRepM * mStride + numRepN * 2);
std::function<void(int, int)> loadB = getLoadMatrixFn(
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
const int fcSize = 4 * numRepM * numRepN;
SmallVector<Value> fc(fcSize);
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");
for (int i = 0; i < fc.size(); i++)
fc[i] = hc[0];
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
PTXBuilder builder;
auto &mma = *builder.create(helper.getMmaInstr().str());
auto retArgs = builder.newListOperand(4, "=r");
auto aArgs = builder.newListOperand({
{ha[{m, k}], "r"},
{ha[{m + 1, k}], "r"},
{ha[{m, k + 1}], "r"},
{ha[{m + 1, k + 1}], "r"},
});
auto bArgs =
builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");
auto cArgs = builder.newListOperand();
for (int i = 0; i < hc.size(); ++i) {
cArgs->listAppend(builder.newOperand(
hc[i], std::to_string(i))); // reuse the output registers
for (int i = 0; i < 4; ++i) {
cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
std::to_string(i)));
// reuse the output registers
}
mma(retArgs, aArgs, bArgs, cArgs);
Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
fc[(m + 0) * mStride + (n * 2 + 0)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(0));
fc[(m + 0) * mStride + (n * 2 + 1)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(1));
fc[(m + 1) * mStride + (n * 2 + 0)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(2));
fc[(m + 1) * mStride + (n * 2 + 1)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(3));
for (int i = 0; i < 4; i++)
fc[m * colsPerThread + 4 * n + i] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
};
// Main program
for (unsigned k = 0; k < numRepK; ++k) {
for (unsigned m = 0; m < numRepM; ++m)
loadA(2 * m, 2 * k);
@@ -2741,6 +2896,9 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem = b.create<LLVM::AddressOfOp>(loc, global);
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
}
} // namespace


@@ -87,7 +87,6 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
} else {
assert(0 && "Unimplemented usage of getShapePerCTA");
}
return shape;
}
@@ -104,7 +103,7 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
assert(0 && "Unimplemented usage of getOrder");
return {};
}
}
};
} // namespace gpu
} // namespace triton
@@ -215,9 +214,12 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
}
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
int threads = product(getWarpsPerCTA());
int numElem = product(shape);
return numElem / threads;
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet");
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
return elemsCol * elemsRow;
}
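A quick sanity check of the new per-thread element count for the version-2 mma layout (the shape below is hypothetical):

def mma_v2_elems_per_thread(shape, warps_per_cta):
    # Mirrors MmaEncodingAttr::getElemsPerThread for version 2.
    ceil = lambda a, b: -(-a // b)
    elems_col = ceil(shape[0], 16 * warps_per_cta[0]) * 2
    elems_row = ceil(shape[1], 8 * warps_per_cta[1]) * 2
    return elems_col * elems_row

# a 128x128 accumulator with 2x2 warps: each of the 128 threads holds 128 elements
assert mma_v2_elems_per_thread([128, 128], [2, 2]) == 128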
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

python/tests/test_gemm.py Normal file

@@ -0,0 +1,52 @@
import pytest
import torch
from torch.testing import assert_allclose
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr
):
offs_m = tl.arange(0, M)
offs_n = tl.arange(0, N)
offs_k = tl.arange(0, K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
c = tl.dot(a, b)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, c)
# TODO: num_warps can only be 4 for now
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
[128, 256, 32, 4],
[256, 128, 16, 4],
[128, 16, 32, 4],
[32, 128, 64, 4],
])
def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
grid = lambda META: (1, )
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
M=SIZE_M, N=SIZE_N, K=SIZE_K,
num_warps=NUM_WARPS)
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full")
assert_allclose(c, golden, rtol=1e-3, atol=1e-3)


@@ -910,7 +910,7 @@ def ptx_get_version(cuda_version) -> int:
def path_to_ptxas():
prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", os.environ.get('CUDA_PATH', default_cuda_dir())]
for prefix in prefixes:
ptxas = os.path.join(prefix, "bin", "ptxas")
if os.path.exists(ptxas):


@@ -299,6 +299,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_alloc_tensor
func @basic_alloc_tensor() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK-NEXT: llvm.bitcast
// CHECK-NEXT: llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr
// CHECK-NEXT: llvm.bitcast
@@ -315,13 +316,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_extract_slice
func @basic_extract_slice() {
// CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast %[[BASE0]]
// CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
// CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
// CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET1]]]
// CHECK-NEXT: %[[BASE2:.*]] = llvm.bitcast
// CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
// CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
// CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
// CHECK-NEXT: llvm.getelementptr %[[BASE2]][%[[OFFSET3]]]
%index = arith.constant 1 : i32
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = triton_gpu.extract_slice %0, %index {axis = 0: i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
@@ -515,3 +517,20 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<16384 x i8>
// CHECK-LABEL: convert_layout_blocked_shared
func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
return
}
}