|
|
|
@@ -457,6 +457,25 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
|
|
|
|
|
}
|
|
|
|
|
return results;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static SharedMemoryObject
|
|
|
|
|
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
|
|
|
|
return SharedMemoryObject(/*base=*/elems[0],
|
|
|
|
|
/*strides=*/{elems.begin() + 1, elems.end()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value
|
|
|
|
|
getStructFromSharedMemoryObject(Location loc,
|
|
|
|
|
const SharedMemoryObject &smemObj,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
auto elems = smemObj.getElems();
|
|
|
|
|
auto types = smemObj.getTypes();
|
|
|
|
|
auto structTy =
|
|
|
|
|
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
|
|
|
|
return getStructFromElements(loc, elems, rewriter, structTy);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename SourceOp>
|
|
|
|
@@ -830,25 +849,6 @@ public:
|
|
|
|
|
return base;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static SharedMemoryObject
|
|
|
|
|
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
|
|
|
|
return SharedMemoryObject(/*base=*/elems[0],
|
|
|
|
|
/*strides=*/{elems.begin() + 1, elems.end()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value
|
|
|
|
|
getStructFromSharedMemoryObject(Location loc,
|
|
|
|
|
const SharedMemoryObject &smemObj,
|
|
|
|
|
ConversionPatternRewriter &rewriter) {
|
|
|
|
|
auto elems = smemObj.getElems();
|
|
|
|
|
auto types = smemObj.getTypes();
|
|
|
|
|
auto structTy =
|
|
|
|
|
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
|
|
|
|
return getStructFromElements(loc, elems, rewriter, structTy);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
const Allocation *allocation;
|
|
|
|
|
Value smem;
|
|
|
|
@@ -3566,7 +3566,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
|
|
|
|
|
A.getType().cast<RankedTensorType>().getElementType().isF32())
|
|
|
|
|
A.getType().cast<RankedTensorType>().getElementType().isF32() &&
|
|
|
|
|
!op.allowTF32())
|
|
|
|
|
return convertFMADot(op, adaptor, rewriter);
|
|
|
|
|
|
|
|
|
|
llvm::report_fatal_error(
|
|
|
|
@@ -4385,6 +4386,90 @@ private:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Helper for conversion of FMA DotOp.
|
|
|
|
|
struct DotOpFMAConversionHelper {
|
|
|
|
|
Attribute layout;
|
|
|
|
|
MLIRContext *ctx{};
|
|
|
|
|
|
|
|
|
|
using ValueTable = std::map<std::pair<int, int>, Value>;
|
|
|
|
|
|
|
|
|
|
explicit DotOpFMAConversionHelper(Attribute layout)
|
|
|
|
|
: layout(layout), ctx(layout.getContext()) {}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> getThreadIds(Value threadId,
|
|
|
|
|
ArrayRef<unsigned> shapePerCTA,
|
|
|
|
|
ArrayRef<unsigned> order,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
Location loc) const;
|
|
|
|
|
|
|
|
|
|
Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
|
|
|
|
|
Location loc, ConversionPatternRewriter &rewriter) const;
|
|
|
|
|
|
|
|
|
|
Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
|
|
|
|
|
Location loc, ConversionPatternRewriter &rewriter) const;
|
|
|
|
|
|
|
|
|
|
ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
|
|
|
|
|
int sizePerThread,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
Location loc) const;
|
|
|
|
|
|
|
|
|
|
Value getStructFromValueTable(ValueTable vals,
|
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
|
Location loc) const {
|
|
|
|
|
SmallVector<Type> elemTypes(vals.size(), f32_ty);
|
|
|
|
|
SmallVector<Value> elems;
|
|
|
|
|
elems.reserve(vals.size());
|
|
|
|
|
for (auto &item : vals) {
|
|
|
|
|
elems.push_back(item.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type structTy = struct_ty(elemTypes);
|
|
|
|
|
return getStructFromElements(loc, elems, rewriter, structTy);
|
|
|
|
|
}
|
|
|
|
|
// get number of elements per thread for $a or $b.
|
|
|
|
|
static int getNumElemsPerThread(ArrayRef<int64_t> shape,
|
|
|
|
|
DotOperandEncodingAttr dotOpLayout) {
|
|
|
|
|
auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(blockedLayout);
|
|
|
|
|
auto sizePerThread = getSizePerThread(blockedLayout);
|
|
|
|
|
auto order = blockedLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
// TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it
|
|
|
|
|
// if not.
|
|
|
|
|
int K = dotOpLayout.getOpIdx() == 0 ? shape[1] : shape[0];
|
|
|
|
|
int otherDim = dotOpLayout.getOpIdx() == 1 ? shape[1] : shape[0];
|
|
|
|
|
|
|
|
|
|
bool isM = dotOpLayout.getOpIdx() == 0;
|
|
|
|
|
int shapePerCTAMN = getShapePerCTAForMN(blockedLayout, isM);
|
|
|
|
|
int sizePerThreadMN = getsizePerThreadForMN(blockedLayout, isM);
|
|
|
|
|
return K * std::max<int>(otherDim / shapePerCTAMN, 1) * sizePerThreadMN;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get shapePerCTA for M or N axis.
|
|
|
|
|
static int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
|
|
|
|
|
auto order = layout.getOrder();
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(layout);
|
|
|
|
|
|
|
|
|
|
int mShapePerCTA =
|
|
|
|
|
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
|
|
|
|
int nShapePerCTA =
|
|
|
|
|
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
|
|
|
|
return isM ? mShapePerCTA : nShapePerCTA;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get sizePerThread for M or N axis.
|
|
|
|
|
static int getsizePerThreadForMN(BlockedEncodingAttr layout, bool isM) {
|
|
|
|
|
auto order = layout.getOrder();
|
|
|
|
|
auto sizePerThread = getSizePerThread(layout);
|
|
|
|
|
|
|
|
|
|
int mSizePerThread =
|
|
|
|
|
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
|
|
|
|
int nSizePerThread =
|
|
|
|
|
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
|
|
|
|
return isM ? mSizePerThread : nSizePerThread;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
|
|
|
|
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
|
|
|
|
@@ -4393,14 +4478,15 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
|
|
|
|
Value src = op.src();
|
|
|
|
|
Value dst = op.result();
|
|
|
|
|
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
|
|
|
|
|
// TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
|
|
|
|
|
// is an attribute of DotOp.
|
|
|
|
|
// TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
|
|
|
|
|
// an Op instance.
|
|
|
|
|
bool allowTF32 = false;
|
|
|
|
|
bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
|
|
|
|
|
mmaLayout.getVersion());
|
|
|
|
|
|
|
|
|
|
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
|
|
|
|
|
Value res;
|
|
|
|
|
|
|
|
|
|
if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
|
|
|
|
|
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
|
|
|
|
rewriter, getTypeConverter(),
|
|
|
|
@@ -4459,7 +4545,25 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
|
|
|
|
} else if (auto blockedLayout =
|
|
|
|
|
dotOperandLayout.getParent()
|
|
|
|
|
.dyn_cast_or_null<BlockedEncodingAttr>()) {
|
|
|
|
|
assert(false && "Blocked layout is not supported yet");
|
|
|
|
|
// TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
|
|
|
|
|
// is an attribute of DotOp.
|
|
|
|
|
bool allowTF32 = false;
|
|
|
|
|
bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
|
|
|
|
|
if (isFMADot) {
|
|
|
|
|
auto dotOpLayout =
|
|
|
|
|
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
|
|
|
|
auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
|
|
|
|
DotOpFMAConversionHelper helper(blockedLayout);
|
|
|
|
|
auto thread = getThreadId(rewriter, loc);
|
|
|
|
|
if (dotOpLayout.getOpIdx() == 0) { // $a
|
|
|
|
|
res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
|
|
|
|
|
rewriter);
|
|
|
|
|
} else { // $b
|
|
|
|
|
res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
|
|
|
|
|
rewriter);
|
|
|
|
|
}
|
|
|
|
|
} else
|
|
|
|
|
assert(false && "Unsupported dot operand layout found");
|
|
|
|
|
} else {
|
|
|
|
|
assert(false && "Unsupported dot operand layout found");
|
|
|
|
|
}
|
|
|
|
@@ -4925,6 +5029,183 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
|
|
|
|
|
return rcds;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value DotOpFMAConversionHelper::loadA(
|
|
|
|
|
Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
|
|
|
|
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
auto aShape = aTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
auto aOrder = aLayout.getOrder();
|
|
|
|
|
auto order = dLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
bool isARow = aOrder[0] == 1;
|
|
|
|
|
|
|
|
|
|
int strideAM = isARow ? aShape[1] : 1;
|
|
|
|
|
int strideAK = isARow ? 1 : aShape[0];
|
|
|
|
|
int strideA0 = isARow ? strideAK : strideAM;
|
|
|
|
|
int strideA1 = isARow ? strideAM : strideAK;
|
|
|
|
|
int lda = isARow ? strideAM : strideAK;
|
|
|
|
|
int aNumPtr = 8;
|
|
|
|
|
int bNumPtr = 8;
|
|
|
|
|
int NK = aShape[1];
|
|
|
|
|
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(dLayout);
|
|
|
|
|
auto sizePerThread = getSizePerThread(dLayout);
|
|
|
|
|
|
|
|
|
|
Value _0 = i32_val(0);
|
|
|
|
|
|
|
|
|
|
Value mContig = i32_val(sizePerThread[order[1]]);
|
|
|
|
|
Value nContig = i32_val(sizePerThread[order[0]]);
|
|
|
|
|
|
|
|
|
|
// threadId in blocked layout
|
|
|
|
|
auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
|
|
|
|
|
|
|
|
|
|
Value threadIdM = threadIds[0];
|
|
|
|
|
Value threadIdN = threadIds[1];
|
|
|
|
|
|
|
|
|
|
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
|
|
|
|
Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
|
|
|
|
|
SmallVector<Value> aOff(aNumPtr);
|
|
|
|
|
for (int i = 0; i < aNumPtr; ++i) {
|
|
|
|
|
aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto aSmem =
|
|
|
|
|
ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
|
|
|
|
|
loc, llA, rewriter);
|
|
|
|
|
|
|
|
|
|
Type f32PtrTy = ptr_ty(f32_ty);
|
|
|
|
|
SmallVector<Value> aPtrs(aNumPtr);
|
|
|
|
|
for (int i = 0; i < aNumPtr; ++i)
|
|
|
|
|
aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
|
|
|
|
|
|
|
|
|
|
ValueTable has;
|
|
|
|
|
int M = aShape[aOrder[1]];
|
|
|
|
|
|
|
|
|
|
int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
|
|
|
|
|
int mSizePerThread = getsizePerThreadForMN(dLayout, true /*isM*/);
|
|
|
|
|
|
|
|
|
|
for (unsigned k = 0; k < NK; ++k) {
|
|
|
|
|
for (unsigned m = 0; m < M; m += mShapePerCTA)
|
|
|
|
|
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
|
|
|
|
|
if (!has.count({m + mm, k})) {
|
|
|
|
|
Value pa = gep(f32PtrTy, aPtrs[0],
|
|
|
|
|
i32_val((m + mm) * strideAM + k * strideAK));
|
|
|
|
|
Value va = load(pa);
|
|
|
|
|
has[{m + mm, k}] = va;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return getStructFromValueTable(has, rewriter, loc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value DotOpFMAConversionHelper::loadB(
|
|
|
|
|
Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
|
|
|
|
|
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
|
|
|
|
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
auto bShape = bTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
auto bOrder = bLayout.getOrder();
|
|
|
|
|
auto order = dLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
bool isBRow = bOrder[0] == 1;
|
|
|
|
|
|
|
|
|
|
int strideBN = isBRow ? 1 : bShape[0];
|
|
|
|
|
int strideBK = isBRow ? bShape[1] : 1;
|
|
|
|
|
int strideB0 = isBRow ? strideBN : strideBK;
|
|
|
|
|
int strideB1 = isBRow ? strideBK : strideBN;
|
|
|
|
|
int ldb = isBRow ? strideBK : strideBN;
|
|
|
|
|
int bNumPtr = 8;
|
|
|
|
|
int NK = bShape[0];
|
|
|
|
|
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(dLayout);
|
|
|
|
|
auto sizePerThread = getSizePerThread(dLayout);
|
|
|
|
|
|
|
|
|
|
Value _0 = i32_val(0);
|
|
|
|
|
|
|
|
|
|
Value mContig = i32_val(sizePerThread[order[1]]);
|
|
|
|
|
Value nContig = i32_val(sizePerThread[order[0]]);
|
|
|
|
|
|
|
|
|
|
// threadId in blocked layout
|
|
|
|
|
auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
|
|
|
|
|
Value threadIdN = threadIds[1];
|
|
|
|
|
|
|
|
|
|
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
|
|
|
|
|
Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
|
|
|
|
|
SmallVector<Value> bOff(bNumPtr);
|
|
|
|
|
for (int i = 0; i < bNumPtr; ++i) {
|
|
|
|
|
bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto bSmem =
|
|
|
|
|
ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
|
|
|
|
|
loc, llB, rewriter);
|
|
|
|
|
|
|
|
|
|
Type f32PtrTy = ptr_ty(f32_ty);
|
|
|
|
|
SmallVector<Value> bPtrs(bNumPtr);
|
|
|
|
|
for (int i = 0; i < bNumPtr; ++i)
|
|
|
|
|
bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
|
|
|
|
|
|
|
|
|
|
int N = bShape[bOrder[0]];
|
|
|
|
|
ValueTable hbs;
|
|
|
|
|
|
|
|
|
|
int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
|
|
|
|
|
int nSizePerThread = getsizePerThreadForMN(dLayout, false /*isM*/);
|
|
|
|
|
|
|
|
|
|
for (unsigned k = 0; k < NK; ++k)
|
|
|
|
|
for (unsigned n = 0; n < N; n += nShapePerCTA)
|
|
|
|
|
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
|
|
|
|
Value pb = gep(f32PtrTy, bPtrs[0],
|
|
|
|
|
i32_val((n + nn) * strideBN + k * strideBK));
|
|
|
|
|
Value vb = load(pb);
|
|
|
|
|
hbs[{n + nn, k}] = vb;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return getStructFromValueTable(hbs, rewriter, loc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DotOpFMAConversionHelper::ValueTable
|
|
|
|
|
DotOpFMAConversionHelper::getValueTableFromStruct(
|
|
|
|
|
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
|
|
|
|
|
ConversionPatternRewriter &rewriter, Location loc) const {
|
|
|
|
|
ValueTable res;
|
|
|
|
|
auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
|
|
|
|
|
loc, val, rewriter);
|
|
|
|
|
int id = 0;
|
|
|
|
|
std::set<std::pair<int, int>> keys; // ordered
|
|
|
|
|
for (unsigned k = 0; k < K; ++k) {
|
|
|
|
|
for (unsigned m = 0; m < n0; m += shapePerCTA)
|
|
|
|
|
for (unsigned mm = 0; mm < sizePerThread; ++mm) {
|
|
|
|
|
keys.insert({m + mm, k});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &key : llvm::enumerate(keys)) {
|
|
|
|
|
res[key.value()] = elems[key.index()];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
|
|
|
|
|
Value threadId, ArrayRef<unsigned int> shapePerCTA,
|
|
|
|
|
ArrayRef<unsigned int> order, ConversionPatternRewriter &rewriter,
|
|
|
|
|
Location loc) const {
|
|
|
|
|
int dim = order.size();
|
|
|
|
|
SmallVector<Value> threadIds(dim);
|
|
|
|
|
for (unsigned k = 0; k < dim - 1; k++) {
|
|
|
|
|
Value dimK = i32_val(shapePerCTA[order[k]]);
|
|
|
|
|
Value rem = urem(threadId, dimK);
|
|
|
|
|
threadId = udiv(threadId, dimK);
|
|
|
|
|
threadIds[order[k]] = rem;
|
|
|
|
|
}
|
|
|
|
|
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
|
|
|
|
|
threadIds[order[dim - 1]] = urem(threadId, dimK);
|
|
|
|
|
return threadIds;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
@@ -4948,120 +5229,68 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
|
|
|
|
auto bShape = bTensorTy.getShape();
|
|
|
|
|
auto cShape = cTensorTy.getShape();
|
|
|
|
|
|
|
|
|
|
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
|
|
|
|
auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
auto aOrder = aLayout.getOrder();
|
|
|
|
|
auto bOrder = bLayout.getOrder();
|
|
|
|
|
|
|
|
|
|
ValueTable has, hbs;
|
|
|
|
|
int mShapePerCTA{-1}, nShapePerCTA{-1};
|
|
|
|
|
int mSizePerThread{-1}, nSizePerThread{-1};
|
|
|
|
|
ArrayRef<unsigned> aOrder, bOrder;
|
|
|
|
|
Value llA, llB;
|
|
|
|
|
BlockedEncodingAttr dLayout =
|
|
|
|
|
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto order = dLayout.getOrder();
|
|
|
|
|
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
|
|
|
|
|
|
|
|
|
bool isARow = aOrder[0] == 1;
|
|
|
|
|
bool isBRow = bOrder[0] == 1;
|
|
|
|
|
DotOpFMAConversionHelper helper(dLayout);
|
|
|
|
|
if (auto aDotOpLayout =
|
|
|
|
|
aTensorTy.getEncoding()
|
|
|
|
|
.dyn_cast<DotOperandEncodingAttr>()) { // get input from
|
|
|
|
|
// convert_layout
|
|
|
|
|
auto bDotOpLayout =
|
|
|
|
|
bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
|
|
|
|
|
auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
|
|
|
|
auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
int strideAM = isARow ? aShape[1] : 1;
|
|
|
|
|
int strideAK = isARow ? 1 : aShape[0];
|
|
|
|
|
int strideBN = isBRow ? 1 : bShape[0];
|
|
|
|
|
int strideBK = isBRow ? bShape[1] : 1;
|
|
|
|
|
int strideA0 = isARow ? strideAK : strideAM;
|
|
|
|
|
int strideA1 = isARow ? strideAM : strideAK;
|
|
|
|
|
int strideB0 = isBRow ? strideBN : strideBK;
|
|
|
|
|
int strideB1 = isBRow ? strideBK : strideBN;
|
|
|
|
|
int lda = isARow ? strideAM : strideAK;
|
|
|
|
|
int ldb = isBRow ? strideBK : strideBN;
|
|
|
|
|
int aPerPhase = aLayout.getPerPhase();
|
|
|
|
|
int aMaxPhase = aLayout.getMaxPhase();
|
|
|
|
|
int bPerPhase = bLayout.getPerPhase();
|
|
|
|
|
int bMaxPhase = bLayout.getMaxPhase();
|
|
|
|
|
int aNumPtr = 8;
|
|
|
|
|
int bNumPtr = 8;
|
|
|
|
|
int NK = aShape[1];
|
|
|
|
|
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(dLayout);
|
|
|
|
|
assert(bLayout);
|
|
|
|
|
llA = adaptor.a();
|
|
|
|
|
llB = adaptor.b();
|
|
|
|
|
} else if (auto aLayout =
|
|
|
|
|
aTensorTy.getEncoding()
|
|
|
|
|
.dyn_cast<SharedEncodingAttr>()) { // load input from smem
|
|
|
|
|
auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
|
|
|
|
assert(bLayout);
|
|
|
|
|
Value thread = getThreadId(rewriter, loc);
|
|
|
|
|
llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
|
|
|
|
|
llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sizePerThread = getSizePerThread(dLayout);
|
|
|
|
|
auto shapePerCTA = getShapePerCTA(dLayout);
|
|
|
|
|
|
|
|
|
|
Value _0 = i32_val(0);
|
|
|
|
|
int K = aShape[1];
|
|
|
|
|
int M = aShape[0];
|
|
|
|
|
int N = bShape[1];
|
|
|
|
|
|
|
|
|
|
Value mContig = i32_val(sizePerThread[order[1]]);
|
|
|
|
|
Value nContig = i32_val(sizePerThread[order[0]]);
|
|
|
|
|
mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
|
|
|
|
mSizePerThread =
|
|
|
|
|
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
|
|
|
|
nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
|
|
|
|
|
nSizePerThread =
|
|
|
|
|
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
|
|
|
|
|
|
|
|
|
|
// threadId in blocked layout
|
|
|
|
|
SmallVector<Value> threadIds;
|
|
|
|
|
{
|
|
|
|
|
int dim = cShape.size();
|
|
|
|
|
threadIds.resize(dim);
|
|
|
|
|
for (unsigned k = 0; k < dim - 1; k++) {
|
|
|
|
|
Value dimK = i32_val(shapePerCTA[order[k]]);
|
|
|
|
|
Value rem = urem(threadId, dimK);
|
|
|
|
|
threadId = udiv(threadId, dimK);
|
|
|
|
|
threadIds[order[k]] = rem;
|
|
|
|
|
}
|
|
|
|
|
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
|
|
|
|
|
threadIds[order[dim - 1]] = urem(threadId, dimK);
|
|
|
|
|
}
|
|
|
|
|
has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
|
|
|
|
|
rewriter, loc);
|
|
|
|
|
hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
|
|
|
|
|
rewriter, loc);
|
|
|
|
|
|
|
|
|
|
Value threadIdM = threadIds[0];
|
|
|
|
|
Value threadIdN = threadIds[1];
|
|
|
|
|
|
|
|
|
|
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
|
|
|
|
Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
|
|
|
|
|
SmallVector<Value> aOff(aNumPtr);
|
|
|
|
|
for (int i = 0; i < aNumPtr; ++i) {
|
|
|
|
|
aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
|
|
|
|
|
Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
|
|
|
|
|
SmallVector<Value> bOff(bNumPtr);
|
|
|
|
|
for (int i = 0; i < bNumPtr; ++i) {
|
|
|
|
|
bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto aSmem = getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
|
|
|
|
|
auto bSmem = getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
|
|
|
|
|
|
|
|
|
|
Type f32PtrTy = ptr_ty(f32_ty);
|
|
|
|
|
SmallVector<Value> aPtrs(aNumPtr);
|
|
|
|
|
for (int i = 0; i < aNumPtr; ++i)
|
|
|
|
|
aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> bPtrs(bNumPtr);
|
|
|
|
|
for (int i = 0; i < bNumPtr; ++i)
|
|
|
|
|
bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
|
|
|
|
|
|
|
|
|
|
ValueTable has, hbs;
|
|
|
|
|
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
|
|
|
|
SmallVector<Value> ret = cc;
|
|
|
|
|
// is this compatible with blocked layout?
|
|
|
|
|
|
|
|
|
|
for (unsigned k = 0; k < NK; k++) {
|
|
|
|
|
for (unsigned k = 0; k < K; k++) {
|
|
|
|
|
int z = 0;
|
|
|
|
|
for (unsigned i = 0; i < cShape[order[1]]; i += shapePerCTA[order[1]])
|
|
|
|
|
for (unsigned j = 0; j < cShape[order[0]]; j += shapePerCTA[order[0]])
|
|
|
|
|
for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
|
|
|
|
|
for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
|
|
|
|
|
unsigned m = order[0] == 1 ? i : j;
|
|
|
|
|
unsigned n = order[0] == 1 ? j : i;
|
|
|
|
|
unsigned mm = order[0] == 1 ? ii : jj;
|
|
|
|
|
unsigned nn = order[0] == 1 ? jj : ii;
|
|
|
|
|
if (!has.count({m + mm, k})) {
|
|
|
|
|
Value pa = gep(f32PtrTy, aPtrs[0],
|
|
|
|
|
i32_val((m + mm) * strideAM + k * strideAK));
|
|
|
|
|
Value va = load(pa);
|
|
|
|
|
has[{m + mm, k}] = va;
|
|
|
|
|
}
|
|
|
|
|
if (!hbs.count({n + nn, k})) {
|
|
|
|
|
Value pb = gep(f32PtrTy, bPtrs[0],
|
|
|
|
|
i32_val((n + nn) * strideBN + k * strideBK));
|
|
|
|
|
Value vb = load(pb);
|
|
|
|
|
hbs[{n + nn, k}] = vb;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (unsigned m = 0; m < M; m += mShapePerCTA)
|
|
|
|
|
for (unsigned n = 0; n < N; n += nShapePerCTA)
|
|
|
|
|
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
|
|
|
|
|
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
|
|
|
|
|
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
|
|
|
|
hbs[{n + nn, k}], ret[z]);
|
|
|
|
|
|
|
|
|
|
++z;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -5138,6 +5367,13 @@ public:
|
|
|
|
|
auto ctx = type.getContext();
|
|
|
|
|
Attribute layout = type.getEncoding();
|
|
|
|
|
auto shape = type.getShape();
|
|
|
|
|
|
|
|
|
|
// TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
|
|
|
|
|
// is bound to an Op instance.
|
|
|
|
|
bool allowTF32 = false;
|
|
|
|
|
bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
|
|
|
|
|
layout.dyn_cast_or_null<DotOperandEncodingAttr>();
|
|
|
|
|
|
|
|
|
|
if (layout &&
|
|
|
|
|
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
|
|
|
|
layout.isa<MmaEncodingAttr>())) {
|
|
|
|
@@ -5158,65 +5394,55 @@ public:
|
|
|
|
|
types.push_back(IntegerType::get(ctx, 32));
|
|
|
|
|
}
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
|
|
|
|
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
|
|
|
|
|
if (mmaLayout.getVersion() == 2) {
|
|
|
|
|
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
|
|
|
|
|
size_t fcSize = 4 * repM * repN;
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (mmaLayout.getVersion() == 1) {
|
|
|
|
|
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
|
|
|
|
int repM = helper.getRepM(shape[0]);
|
|
|
|
|
int repN = helper.getRepN(shape[1]);
|
|
|
|
|
int elems = 8 * repM * repN;
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llvm::errs()
|
|
|
|
|
<< "Unexpected mma layout detected in TritonToLLVMTypeConverter";
|
|
|
|
|
return llvm::None;
|
|
|
|
|
|
|
|
|
|
} else if (auto dot_op_layout =
|
|
|
|
|
} else if (auto dotOpLayout =
|
|
|
|
|
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
|
|
|
|
|
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
|
|
|
|
|
auto wpt = mmaLayout.getWarpsPerCTA();
|
|
|
|
|
Type elemTy = convertType(type.getElementType());
|
|
|
|
|
auto vecSize = 1;
|
|
|
|
|
if (elemTy.getIntOrFloatBitWidth() == 16) {
|
|
|
|
|
vecSize = 2;
|
|
|
|
|
} else if (elemTy.getIntOrFloatBitWidth() == 8) {
|
|
|
|
|
vecSize = 4;
|
|
|
|
|
} else {
|
|
|
|
|
assert(false && "Unsupported element type");
|
|
|
|
|
}
|
|
|
|
|
Type vecTy = vec_ty(elemTy, vecSize);
|
|
|
|
|
if (mmaLayout.getVersion() == 2) {
|
|
|
|
|
if (dot_op_layout.getOpIdx() == 0) { // $a
|
|
|
|
|
int elems =
|
|
|
|
|
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(elems, vecTy));
|
|
|
|
|
}
|
|
|
|
|
if (dot_op_layout.getOpIdx() == 1) { // $b
|
|
|
|
|
int elems =
|
|
|
|
|
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, vecTy));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (isFMADot) { // for parent is blocked layout
|
|
|
|
|
int numElemsPerThread =
|
|
|
|
|
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
|
|
|
|
|
|
|
|
|
|
if (mmaLayout.getVersion() == 1) {
|
|
|
|
|
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
|
|
|
|
|
|
|
|
|
|
if (dot_op_layout.getOpIdx() == 0) { // $a
|
|
|
|
|
int elems = helper.numElemsPerThreadA(type);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, vecTy));
|
|
|
|
|
} else { // for parent is MMA layout
|
|
|
|
|
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
|
|
|
|
|
auto wpt = mmaLayout.getWarpsPerCTA();
|
|
|
|
|
Type elemTy = convertType(type.getElementType());
|
|
|
|
|
auto vecSize = 1;
|
|
|
|
|
if (elemTy.getIntOrFloatBitWidth() == 16) {
|
|
|
|
|
vecSize = 2;
|
|
|
|
|
} else if (elemTy.getIntOrFloatBitWidth() == 8) {
|
|
|
|
|
vecSize = 4;
|
|
|
|
|
} else {
|
|
|
|
|
assert(false && "Unsupported element type");
|
|
|
|
|
}
|
|
|
|
|
if (dot_op_layout.getOpIdx() == 1) { // $b
|
|
|
|
|
int elems = helper.numElemsPerThreadB(type);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, vecTy));
|
|
|
|
|
Type vecTy = vec_ty(elemTy, vecSize);
|
|
|
|
|
if (mmaLayout.getVersion() == 2) {
|
|
|
|
|
if (dotOpLayout.getOpIdx() == 0) { // $a
|
|
|
|
|
int elems =
|
|
|
|
|
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(
|
|
|
|
|
ctx, SmallVector<Type>(elems, vecTy));
|
|
|
|
|
}
|
|
|
|
|
if (dotOpLayout.getOpIdx() == 1) { // $b
|
|
|
|
|
int elems =
|
|
|
|
|
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, vecTy));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (mmaLayout.getVersion() == 1) {
|
|
|
|
|
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
|
|
|
|
|
|
|
|
|
if (dotOpLayout.getOpIdx() == 0) { // $a
|
|
|
|
|
int elems = helper.numElemsPerThreadA(type);
|
|
|
|
|
Type x2Ty = vec_ty(elemTy, 2);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
|
|
|
|
}
|
|
|
|
|
if (dotOpLayout.getOpIdx() == 1) { // $b
|
|
|
|
|
int elems = helper.numElemsPerThreadB(type);
|
|
|
|
|
Type x2Ty = vec_ty(elemTy, 2);
|
|
|
|
|
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|