[BACKEND] Support splat constant on the DotOperandLayout (#1008)
This commit is contained in:
@@ -4,11 +4,13 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
using ::mlir::LLVM::DotOpFMAConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
|
||||
using ::mlir::LLVM::getElementsFromStruct;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::getStructFromElements;
|
||||
using ::mlir::LLVM::MMA16816ConversionHelper;
|
||||
using ::mlir::triton::gpu::getElemsPerThread;
|
||||
|
||||
struct SplatOpConversion
|
||||
@@ -38,6 +40,11 @@ struct SplatOpConversion
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
} else if (auto dotLayout =
|
||||
tensorTy.getEncoding()
|
||||
.dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithDotOperandLayout(
|
||||
dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
|
||||
} else if (auto mmaLayout =
|
||||
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithMmaLayout(
|
||||
@@ -48,6 +55,38 @@ struct SplatOpConversion
|
||||
return {};
|
||||
}
|
||||
|
||||
// Builds the LLVM struct value for a splat-like op whose result tensor uses a
// DotOperandEncodingAttr layout.
//
// The result is a literal LLVM struct holding `numElems` copies of `constVal`,
// each of type `elemType`. `numElems` is the per-thread element count reported
// by the conversion helper that matches the dot operand's parent layout:
//   * MMA parent, Ampere: MMA16816ConversionHelper (A or B by operand index)
//   * MMA parent, Volta:  DotOpMmaV1ConversionHelper (A or B by operand index)
//   * Blocked parent:     DotOpFMAConversionHelper
//
// `resType` must be a RankedTensorType (it is cast unconditionally).
// `typeConverter` is not used by this function.
//
// Any parent layout not listed above, including an MMA layout that is neither
// Ampere nor Volta, trips an assertion instead of silently yielding an empty
// struct.
static Value convertSplatLikeOpWithDotOperandLayout(
    const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
    Type elemType, Value constVal, TypeConverter *typeConverter,
    ConversionPatternRewriter &rewriter, Location loc) {
  auto tensorTy = resType.cast<RankedTensorType>();
  auto shape = tensorTy.getShape();
  auto parent = layout.getParent();
  // Operand index 0 selects the "A" helper below; any other index selects "B".
  const bool isOperandA = layout.getOpIdx() == 0;

  int numElems{};
  if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
    if (mmaLayout.isAmpere()) {
      numElems = isOperandA
                     ? MMA16816ConversionHelper::getANumElemsPerThread(
                           tensorTy, mmaLayout.getWarpsPerCTA()[0])
                     : MMA16816ConversionHelper::getBNumElemsPerThread(
                           tensorTy, mmaLayout.getWarpsPerCTA()[1]);
    } else if (mmaLayout.isVolta()) {
      DotOpMmaV1ConversionHelper helper(mmaLayout);
      numElems = isOperandA ? helper.numElemsPerThreadA(shape, {0, 1})
                            : helper.numElemsPerThreadB(shape, {0, 1});
    } else {
      // Previously this case fell through with numElems == 0 and produced an
      // empty struct without any diagnostic.
      assert(false && "Unsupported MMA layout version found");
    }
  } else if (parent.isa<BlockedEncodingAttr>()) {
    numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
  } else {
    assert(false && "Unsupported layout found");
  }

  // Every struct field has the same type and holds the same splat value.
  auto structTy = LLVM::LLVMStructType::getLiteral(
      rewriter.getContext(), SmallVector<Type>(numElems, elemType));
  return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
                               rewriter, structTy);
}
|
||||
|
||||
static Value convertSplatLikeOpWithMmaLayout(
|
||||
const MmaEncodingAttr &layout, Type resType, Type elemType,
|
||||
Value constVal, TypeConverter *typeConverter,
|
||||
|
Reference in New Issue
Block a user