@@ -67,8 +67,6 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}

// Add other specification if needed...

} // namespace

// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
@@ -89,8 +87,12 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define vec_ty(type, num) VectorType::get(num, type)
#define void_ty LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__)

// Creator for constant
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
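// Note: the shortcut macros above expand against a `rewriter`, a `loc` (and,
// for the type helpers, a `ctx`) that must already be in scope at the point of
// use. A purely illustrative use inside a conversion pattern would be:
//   Value one = i32_val(1);   // LLVM::ConstantOp holding i32 1
//   Value word = load(ptr);   // LLVM::LoadOp at `loc`
//   store(word, ptr);         // LLVM::StoreOp at `loc`
//   barrier;                  // mlir::gpu::BarrierOp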
@@ -109,6 +111,18 @@ namespace type = mlir::triton::type;

class TritonGPUToLLVMTypeConverter;

// TODO: keep these until we have better debug log utilities
template <typename T>
void print_array(ArrayRef<T> array, const std::string &str) {
std::cout << str << ": ";
for (const T &e : array)
std::cout << e << ",";
std::cout << std::endl;
}
template <typename T> void print_scalar(const T &e, const std::string &str) {
std::cout << str << ": " << e << std::endl;
}

// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed in header files in MLIR v14
@@ -360,6 +374,13 @@ public:
return threadId;
}

Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc,
int64_t value) const {
return rewriter.create<LLVM::ConstantOp>(
loc, this->getTypeConverter()->getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), value));
}

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
@@ -859,7 +880,7 @@ struct StoreOpConversion
asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint));
}

// TODO(Superjomn) Need to check masks before vectorize the load for all
// TODO(Superjomn) Need to check masks before vectorize the load for
// the values share one predicate? Here assume all the mask values are
// the same.
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
@@ -1060,7 +1081,6 @@ struct LoadOpConversion
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
@@ -1413,87 +1433,24 @@ public:
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (srcLayout.isa<BlockedEncodingAttr>() &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerBlockedToShared(op, adaptor, rewriter);
}
if ((!srcLayout.isa<BlockedEncodingAttr>() &&
!srcLayout.isa<MmaEncodingAttr>()) ||
(!dstLayout.isa<BlockedEncodingAttr>() &&
!dstLayout.isa<MmaEncodingAttr>())) {
// TODO: to be implemented
llvm::errs() << "Unsupported ConvertLayout found";
return failure();
}
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);

auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
// TODO: confirm this
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
rewriter.create<mlir::gpu::BarrierOp>(loc);
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}

SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
return lowerDistributedToDistributed(op, adaptor, rewriter);
}

private:
@@ -1508,122 +1465,334 @@ private:
return result;
};

// shared memory access for blocked or mma layout
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
SmallVector<Value> multiDimOffsetFirstElem;
Value mmaGrpId;
Value mmaGrpIdP8;
Value mmaThreadIdInGrpM2;
Value mmaThreadIdInGrpM2P1;
if (blockedLayout) {
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (mmaLayout) {
// TODO: simplify these
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
Value warpSize = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), 32);
Value laneId = rewriter.create<LLVM::URemOp>(loc, threadId, warpSize);
Value fourVal = idx_val(4);
mmaGrpId = rewriter.create<LLVM::UDivOp>(loc, laneId, fourVal);
mmaGrpIdP8 = rewriter.create<LLVM::AddOp>(loc, mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp =
rewriter.create<LLVM::URemOp>(loc, laneId, fourVal);
mmaThreadIdInGrpM2 =
rewriter.create<LLVM::MulOp>(loc, mmaThreadIdInGrp, idx_val(2));
mmaThreadIdInGrpM2P1 =
rewriter.create<LLVM::AddOp>(loc, mmaThreadIdInGrpM2, idx_val(1));
}
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
Value smemBase) const;

unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation; we should
// consider caching the index calculation results in case
// performance issues are observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset(rank);
if (blockedLayout) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = rewriter.create<LLVM::AddOp>(
loc, multiDimOffsetFirstElem[d],
createIndexAttrConstant(rewriter, loc, llvmIndexTy,
multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
} else if (mmaLayout) {
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
multiDimOffset[0] = elemId < 2 ? mmaGrpId : mmaGrpIdP8;
multiDimOffset[1] =
elemId % 2 == 0 ? mmaThreadIdInGrpM2 : mmaThreadIdInGrpM2P1;
} else {
assert(0 && "unexpected layout in processReplica");
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;

// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
LogicalResult lowerBlockedToShared(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
};

void ConvertLayoutOpConversion::processReplica(
Location loc, ConversionPatternRewriter &rewriter, bool stNotRd,
RankedTensorType type, ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> outOrd,
SmallVector<Value> &vals, Value smemBase) const {
unsigned accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
auto llvmIndexTy = getTypeConverter()->getIndexType();
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
SmallVector<Value> multiDimOffsetFirstElem;
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
if (blockedLayout) {
multiDimOffsetFirstElem = emitBaseIndexForBlockedLayout(
loc, rewriter, blockedLayout, type.getShape());
} else if (mmaLayout) {
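// The index math below follows the per-lane ownership pattern this code
// assumes for mma v2 fragments: each group of four lanes (laneId / 4) picks
// the index along dimension 0 (plus 8 for the second half of the fragment),
// the lane within the group (laneId % 4) addresses two adjacent elements
// along dimension 1 (hence the *2 and +1), and the warp offsets of 16 and 8
// come from the per-warp tile the surrounding code uses.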
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// auto multiDimWarpId =
// delinearize(rewriter, loc, warpId, mmaLayout.getWarpsPerCTA());
// TODO: double-check whether this is a documentation bug or a bug in DotConversion
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value four = idx_val(4);
Value mmaGrpId = udiv(laneId, four);
Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
Value mmaThreadIdInGrp = urem(laneId, four);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, idx_val(2));
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, idx_val(1));
Value colWarpOffset = mul(multiDimWarpId[0], idx_val(16));
mmaColIdx[0] = add(mmaGrpId, colWarpOffset);
mmaColIdx[1] = add(mmaGrpIdP8, colWarpOffset);
Value rowWarpOffset = mul(multiDimWarpId[1], idx_val(8));
mmaRowIdx[0] = add(mmaThreadIdInGrpM2, rowWarpOffset);
mmaRowIdx[1] = add(mmaThreadIdInGrpM2P1, rowWarpOffset);
}
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId = getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep);
SmallVector<unsigned> multiDimCTAId(rank);
for (auto it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}

unsigned linearCTAId = getLinearIndex<unsigned>(multiDimCTAId, numCTAs);
// TODO: This is actually redundant index calculation; we should
// consider caching the index calculation results in case
// performance issues are observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset(rank);
if (blockedLayout) {
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, blockedLayout.getSizePerThread());
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] =
add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], vVal);
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value vVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), v);
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, vVal);
}
} else if (mmaLayout) {
assert(rank == 2);
assert(mmaLayout.getVersion() == 2 &&
"mmaLayout ver1 not implemented yet");
multiDimOffset[0] = elemId < 2 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
assert(0 && "unexpected layout in processReplica");
}
Value offset =
linearize(rewriter, loc, reorder<Value>(multiDimOffset, outOrd),
reorder<unsigned>(paddedRepShape, outOrd));
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bit_cast(ptr_ty(vecTy, 3), ptr);
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, idx_val(v));
}
}
}
}
}

LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bit_cast(elemPtrTy, smemBase);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
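// A brief reading of the loop below: for every dimension, one "replica"
// covers max(srcShapePerCTA[d], dstShapePerCTA[d]) elements (capped by the
// tensor shape), inNumCTAsEachRep / outNumCTAsEachRep record how many source
// and destination CTA tiles make up a single replica, and numReplicates
// counts how many such replicas are needed to cover the whole tensor.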
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
unsigned accumNumReplicates = product<unsigned>(numReplicates);
unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

unsigned outElems = getElemsPerThread(dstLayout, shape);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);

for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId = getMultiDimIndex<unsigned>(repId, numReplicates);
barrier;
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier;
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
multiDimRepId, outVec, paddedRepShape, outOrd, outVals,
smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}

SmallVector<Type> types(outElems, llvmElemTy);
Type structTy = struct_ty(getContext(), types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);

return success();
};

LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = srcBlockedLayout.getOrder();
auto outOrd = dstSharedLayout.getOrder();
unsigned inVec =
inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
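// minVec is the widest contiguous word both sides agree on: the source can
// hand out sizePerThread[inOrd[0]] consecutive elements per thread (when the
// orders match) and the shared layout expects outVec-wide vectors, so the
// code below packs and stores data in words of minVec elements (see wordTy).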
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter);
unsigned srcAccumSizeInThreads =
product<unsigned>(srcBlockedLayout.getSizePerThread());
auto elemTy = srcTy.getElementType();
auto wordTy = vec_ty(elemTy, minVec);

// TODO: [goostavz] We should cache the result of
// emitBaseIndexForBlockedLayout in case the backend compiler is not able to
// optimize it
SmallVector<Value> multiDimOffsetFirstElem =
emitBaseIndexForBlockedLayout(loc, rewriter, srcBlockedLayout, srcShape);
SmallVector<unsigned> srcShapePerCTA = getShapePerCTA(srcBlockedLayout);
SmallVector<unsigned> reps{ceil<unsigned>(srcShape[0], srcShapePerCTA[0]),
ceil<unsigned>(srcShape[1], srcShapePerCTA[1])};

// Visit each input value in the order they are placed in inVals
//
// Please note that this order is not aware of blockedLayout.getOrder(),
// thus adjacent elements may not belong to the same word. This could be
// improved if we update the element order via emitIndicesForBlockedLayout()
SmallVector<unsigned> wordsInEachRep(2);
wordsInEachRep[0] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[0] / minVec
: srcBlockedLayout.getSizePerThread()[0];
wordsInEachRep[1] = inOrd[0] == 0
? srcBlockedLayout.getSizePerThread()[1]
: srcBlockedLayout.getSizePerThread()[1] / minVec;
Value outVecVal = idx_val(outVec);
Value minVecVal = idx_val(minVec);
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bit_cast(elemPtrTy, smemBase);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
for (unsigned i = 0; i < numElems; ++i) {
if (i % srcAccumSizeInThreads == 0) {
// start of a replication
for (unsigned w = 0; w < numWordsEachRep; ++w) {
wordVecs[w] = undef(wordTy);
}
}
unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
multiDimIdxInNanoTile[inOrd[0]] /= minVec;
unsigned wordVecIdx =
getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
wordVecs[wordVecIdx] =
insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], idx_val(pos));
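// Each thread first packs minVec consecutive elements (along the contiguous
// dimension inOrd[0]) into a word vector; the words are flushed to shared
// memory only once the whole per-thread nano-tile has been visited (below).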

if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) {
// end of replication, store the vectors into shared memory
unsigned linearRepIdx = i / srcAccumSizeInThreads;
auto multiDimRepIdx = getMultiDimIndex<unsigned>(linearRepIdx, reps);
for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep;
++linearWordIdx) {
// step 1: recover the multidim_index from the index of input_elements
auto multiDimWordIdx =
getMultiDimIndex<unsigned>(linearWordIdx, wordsInEachRep);
SmallVector<Value> multiDimIdx(2);
auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] +
multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1);
auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] +
multiDimWordIdx[1] * (inOrd[0] == 1 ? minVec : 1);
multiDimIdx[0] = add(multiDimOffsetFirstElem[0], idx_val(wordOffset0));
multiDimIdx[1] = add(multiDimOffsetFirstElem[1], idx_val(wordOffset1));

// step 2: do swizzling
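// In scalar terms (as computed below): with c = multiDimIdx[inOrd[0]]
// (the contiguous dimension) and r = multiDimIdx[inOrd[1]], the phase is
// (r / perPhase) % maxPhase and the word lands at offset
//   r * srcShape[inOrd[0]] + ((c / outVec) ^ phase) * outVec
//     + ((c % outVec) / minVec) * minVec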
Value remained = urem(multiDimIdx[inOrd[0]], outVecVal);
multiDimIdx[inOrd[0]] = udiv(multiDimIdx[inOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[inOrd[1]], idx_val(srcShape[inOrd[0]]));
Value phaseId = udiv(multiDimIdx[inOrd[1]], idx_val(perPhase));
phaseId = urem(phaseId, idx_val(maxPhase));
Value off_0 = xor_(multiDimIdx[inOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));
Value offset = add(off_1, off_0);

// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bit_cast(ptr_ty(wordTy, 3), smemAddr);
store(wordVecs[linearWordIdx], smemAddr);
}
}
}
// TODO: double-check whether the barrier is necessary here
barrier;
rewriter.replaceOp(op, smemBase);
return success();
}
/// ====================== dot codegen begin ==========================

// Data loader for mma.16816 instruction.
@@ -1843,7 +2012,7 @@ public:

if (canUseLdmatrix)
ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
else if (elemBytes == 4 && needTrans) // tf32 & trans
else if (elemBytes == 4 && needTrans)
ptrIdx = matIdx[order[0]];
else if (elemBytes == 1 && needTrans)
ptrIdx = matIdx[order[0]] * 4;
@@ -2127,10 +2296,6 @@ struct DotOpConversionHelper {
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();

ATensorTy = A.getType().cast<RankedTensorType>();
BTensorTy = B.getType().cast<RankedTensorType>();
DTensorTy = D.getType().cast<RankedTensorType>();
}

// Load SplatLike C which contains a constVal. It simply returns 4 fp32
@@ -2469,7 +2634,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
bool needTrans = kOrder != order[0];

// (a, b) is the coordinate.
auto load = [&, loader, ptrs, offs, needTrans](int a, int b) {
auto load = [=, &vals, &helper, &ld2](int a, int b) {
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, helper.getMatType(), helper.getShemPtrTy());
@@ -2490,78 +2655,68 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
};

std::function<void(int, int)> loadA;
std::function<void(int, int)> loadB = getLoadMatrixFn(
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

if (aTensorTy.getEncoding()
.dyn_cast<SharedEncodingAttr>()) { // load from smem
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
// load from smem
loadA = getLoadMatrixFn(
A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (auto blockedLayout =
aTensorTy.getEncoding()
.dyn_cast<BlockedEncodingAttr>()) { // load from registers,
// used in gemm fuse
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
// TODO(Superjomn) Port the logic.
assert(false && "Loading A from register is not supported yet.");
} else {
assert(false && "A's layout is not supported.");
}

const unsigned mStride = numRepN * 2;
SmallVector<Value> fc(numRepM * mStride + numRepN * 2);
std::function<void(int, int)> loadB = getLoadMatrixFn(
B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

const int fcSize = 4 * numRepM * numRepN;
SmallVector<Value> fc(fcSize);

// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");
for (int i = 0; i < fc.size(); i++)
fc[i] = hc[0];

auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
PTXBuilder builder;

auto &mma = *builder.create(helper.getMmaInstr().str());

auto retArgs = builder.newListOperand(4, "=r");

auto aArgs = builder.newListOperand({
{ha[{m, k}], "r"},
{ha[{m + 1, k}], "r"},
{ha[{m, k + 1}], "r"},
{ha[{m + 1, k + 1}], "r"},
});

auto bArgs =
builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});

// Currently, we only support a SplatLike C. For the other cases, e.g., C in
// shared layout or blocked layout, we will support them by expanding
// convert_layout.
auto hc = helper.loadSplatLikeC(C, loc, rewriter);
assert(hc.size() == 4UL && "Only splat-like C is supported now");

auto cArgs = builder.newListOperand();
for (int i = 0; i < hc.size(); ++i) {
cArgs->listAppend(builder.newOperand(
hc[i], std::to_string(i))); // reuse the output registers
for (int i = 0; i < 4; ++i) {
cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
std::to_string(i)));
// reuse the output registers
}

mma(retArgs, aArgs, bArgs, cArgs);

Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());

auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};

fc[(m + 0) * mStride + (n * 2 + 0)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(0));
fc[(m + 0) * mStride + (n * 2 + 1)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(1));
fc[(m + 1) * mStride + (n * 2 + 0)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(2));
fc[(m + 1) * mStride + (n * 2 + 1)] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(3));
for (int i = 0; i < 4; i++)
fc[m * colsPerThread + 4 * n + i] =
extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
};

// Main program

for (unsigned k = 0; k < numRepK; ++k) {
for (unsigned m = 0; m < numRepM; ++m)
loadA(2 * m, 2 * k);
@@ -2741,6 +2896,9 @@ void ConvertTritonGPUToLLVM::initSharedMemory(
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem = b.create<LLVM::AddressOfOp>(loc, global);
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
}

} // namespace