[BACKEND] Support splat constant on the DotOperandLayout (#1008)
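The SplatOp lowering previously had no case for result tensors whose encoding is a DotOperandEncodingAttr, so IR such as tl.zeros feeding directly into tl.dot failed with "Unsupported layout found in ConvertSplatLikeOp". This commit adds convertSplatLikeOpWithDotOperandLayout, which derives the per-thread element count from the dot operand's parent layout (Ampere MMA via MMA16816ConversionHelper, Volta MMA via DotOpMmaV1ConversionHelper, a blocked parent via DotOpFMAConversionHelper) and packs that many copies of the constant into the resulting LLVM struct. The previously disabled test_dot_without_load is re-enabled.

A minimal sketch of the user-visible pattern this enables, mirroring the re-enabled test below (the kernel name dot_constants_kernel and the final assertion are illustrative, not part of the commit):

import torch
import triton
import triton.language as tl

@triton.jit
def dot_constants_kernel(out_ptr):
    # Both dot operands are splat constants, so no tl.load is involved;
    # before this change, materializing them in the dot-operand layout
    # failed with "Unsupported layout found" in ConvertSplatLikeOp.
    a = tl.zeros((32, 32), tl.float32)
    b = tl.zeros((32, 32), tl.float32)
    c = tl.dot(a, b)
    offs = tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
    tl.store(out_ptr + offs, c)

out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
dot_constants_kernel[(1,)](out)
assert torch.all(out == 0)  # zeros @ zeros overwrites the ones with zeros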
@@ -4,11 +4,13 @@
 using namespace mlir;
 using namespace mlir::triton;
 
+using ::mlir::LLVM::DotOpFMAConversionHelper;
 using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
 using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
 using ::mlir::LLVM::getElementsFromStruct;
 using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
 using ::mlir::LLVM::getStructFromElements;
+using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::triton::gpu::getElemsPerThread;
 
 struct SplatOpConversion
@@ -38,6 +40,11 @@ struct SplatOpConversion
         LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
 
       return getStructFromElements(loc, elems, rewriter, structTy);
+    } else if (auto dotLayout =
+                   tensorTy.getEncoding()
+                       .dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
+      return convertSplatLikeOpWithDotOperandLayout(
+          dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
     } else if (auto mmaLayout =
                    tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
       return convertSplatLikeOpWithMmaLayout(
@@ -48,6 +55,38 @@ struct SplatOpConversion
     return {};
   }
 
+  static Value convertSplatLikeOpWithDotOperandLayout(
+      const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
+      Type elemType, Value constVal, TypeConverter *typeConverter,
+      ConversionPatternRewriter &rewriter, Location loc) {
+    auto tensorTy = resType.cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto parent = layout.getParent();
+    int numElems{};
+    if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
+      if (mmaLayout.isAmpere()) {
+        numElems = layout.getOpIdx() == 0
+                       ? MMA16816ConversionHelper::getANumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[0])
+                       : MMA16816ConversionHelper::getBNumElemsPerThread(
+                             tensorTy, mmaLayout.getWarpsPerCTA()[1]);
+      } else if (mmaLayout.isVolta()) {
+        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        numElems = layout.getOpIdx() == 0
+                       ? helper.numElemsPerThreadA(shape, {0, 1})
+                       : helper.numElemsPerThreadB(shape, {0, 1});
+      }
+    } else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
+      numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
+    } else {
+      assert(false && "Unsupported layout found");
+    }
+    auto structTy = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), SmallVector<Type>(numElems, elemType));
+    return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
+                                 rewriter, structTy);
+  }
+
   static Value convertSplatLikeOpWithMmaLayout(
       const MmaEncodingAttr &layout, Type resType, Type elemType,
       Value constVal, TypeConverter *typeConverter,
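The new helper mirrors the existing convertSplatLikeOpWithMmaLayout path: only the per-thread element count depends on the layout, so each branch merely computes numElems for its parent encoding, and a shared tail packs numElems copies of constVal into an LLVM struct. The second half of the commit re-enables the regression test in the Python test suite: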
@@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
     elif dtype == 'int8':
         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
 
-# FIXME: Unsupported layout found in ConvertSplatLikeOp
-# def test_dot_without_load():
-#     @triton.jit
-#     def kernel(out):
-#         pid = tl.program_id(axis=0)
-#         a = tl.zeros((32, 32), tl.float32)
-#         b = tl.zeros((32, 32), tl.float32)
-#         c = tl.zeros((32, 32), tl.float32)
-#         c = tl.dot(a, b)
-#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
-#         tl.store(pout, c)
-#
-#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
-#     kernel[(1,)](out)
+
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
 
 # ---------------
 # test arange
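With the splat lowering in place, the test above compiles and passes again; something like pytest -k test_dot_without_load (assuming the test lives in its usual python/test/unit/language/test_core.py location) runs it in isolation on a CUDA machine.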