[Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)
This PR does 1. Support the case where $c holding a mma layout, this should be useful in forloop in k-axis in GEMM 2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op] Known issue 1. There is some IO conflict in GEMM with a k-forloop, it is temporarily solved by [adding a barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028) in dot conversion, but we are still working on it, will get a more generic fix for it in the following PR. 2. The parallel pass will result in a buggy instruction result type ```mlir %1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", "" : () -> !llvm.void %1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3> ``` So we temporarily disable it.
This commit is contained in:
@@ -622,6 +622,13 @@ protected:
|
||||
Value smem;
|
||||
};
|
||||
|
||||
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
|
||||
Type resType, Type elemType,
|
||||
Value constVal,
|
||||
TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc);
|
||||
|
||||
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
|
||||
// LLVM::StructType value.
|
||||
//
|
||||
@@ -632,16 +639,26 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(srcType, constVal);
|
||||
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
|
||||
llvm::SmallVector<Value, 4> elems(elemsPerThread, llSrc);
|
||||
llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = bitcast(srcType, constVal);
|
||||
size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
|
||||
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
|
||||
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
|
||||
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
} else if (auto mmaLayout =
|
||||
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
|
||||
return convertSplatLikeOpWithMmaLayout(
|
||||
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
|
||||
} else
|
||||
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
|
||||
|
||||
return Value{};
|
||||
}
|
||||
|
||||
struct SplatOpConversion
|
||||
@@ -2436,8 +2453,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
MLIRContext *ctx = op->getContext();
|
||||
bool allowTF32 = op.allowTF32();
|
||||
|
||||
assert(isSplatLike(C) && "Currently only splat-like C is supported now");
|
||||
|
||||
// Here we assume the DotOp's operands always comes from shared memory.
|
||||
auto AShape = A.getType().cast<RankedTensorType>().getShape();
|
||||
size_t reduceAxis = 1;
|
||||
@@ -2536,6 +2551,31 @@ struct DotOpConversionHelper {
|
||||
mmaType = getTensorCoreTypeFromOperand(operandTy);
|
||||
}
|
||||
|
||||
// Get the M and N of mat instruction shape.
|
||||
static std::tuple<int, int> getMatShapeMN() {
|
||||
// According to DotOpConversionHelper::mmaMatShape, all the matrix shape's
|
||||
// M,N are {8,8}
|
||||
return {8, 8};
|
||||
}
|
||||
|
||||
// Get the M and N of mma instruction shape.
|
||||
static std::tuple<int, int> getInstrShapeMN() {
|
||||
// According to DotOpConversionHelper::mmaInstrShape, all the M,N are {16,8}
|
||||
return {16, 8};
|
||||
}
|
||||
|
||||
static std::tuple<int, int> getRepMN(const RankedTensorType &tensorTy) {
|
||||
auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
int M = tensorTy.getShape()[0];
|
||||
int N = tensorTy.getShape()[1];
|
||||
auto [instrM, instrN] = getInstrShapeMN();
|
||||
int repM = std::max<int>(M / (wpt[0] * instrM), 1);
|
||||
int repN = std::max<int>(N / (wpt[1] * instrN), 1);
|
||||
return {repM, repN};
|
||||
}
|
||||
|
||||
Type getShemPtrTy() const {
|
||||
switch (mmaType) {
|
||||
case TensorCoreType::FP32_FP16_FP16_FP32:
|
||||
@@ -2633,15 +2673,20 @@ struct DotOpConversionHelper {
|
||||
return mmaInstrShape.at(mmaType);
|
||||
}
|
||||
|
||||
static ArrayRef<int> getMmaInstrShape(TensorCoreType tensorCoreType) {
|
||||
assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE &&
|
||||
"Unknown mma type found.");
|
||||
return mmaInstrShape.at(tensorCoreType);
|
||||
}
|
||||
|
||||
ArrayRef<int> getMmaMatShape() const {
|
||||
assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
|
||||
"Unknown mma type found.");
|
||||
return mmaMatShape.at(mmaType);
|
||||
}
|
||||
|
||||
// Deduce the TensorCoreType from either $a or $b's type. This method is not
|
||||
// safe, but we cannot get the DotOp in some getmaMatShape usage case.
|
||||
TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
|
||||
// Deduce the TensorCoreType from either $a or $b's type.
|
||||
static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) {
|
||||
auto tensorTy = operandTy.cast<RankedTensorType>();
|
||||
auto elemTy = tensorTy.getElementType();
|
||||
if (elemTy.isF16())
|
||||
@@ -2814,22 +2859,58 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
// \param operand is either $a or $b's type.
|
||||
inline int getNumRepM(Type operand, int M) const {
|
||||
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
||||
return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
|
||||
return getNumRepM(operand, M, wpt[0]);
|
||||
}
|
||||
|
||||
// \param operand is either $a or $b's type.
|
||||
inline int getNumRepN(Type operand, int N) const {
|
||||
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
||||
return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
|
||||
return getNumRepN(operand, N, wpt[1]);
|
||||
}
|
||||
|
||||
// \param operand is either $a or $b's type.
|
||||
inline int getNumRepK(Type operand, int K) const {
|
||||
auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
|
||||
return getNumRepK_(operand, K);
|
||||
}
|
||||
|
||||
static int getNumRepM(Type operand, int M, int wpt) {
|
||||
auto tensorCoreType =
|
||||
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
|
||||
int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0];
|
||||
return std::max<int>(M / (wpt * mmaInstrM), 1);
|
||||
}
|
||||
|
||||
static int getNumRepN(Type operand, int N, int wpt) {
|
||||
auto tensorCoreType =
|
||||
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
|
||||
int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1];
|
||||
return std::max<int>(N / (wpt * mmaInstrN), 1);
|
||||
}
|
||||
|
||||
static int getNumRepK_(Type operand, int K) {
|
||||
auto tensorCoreType =
|
||||
DotOpConversionHelper::getTensorCoreTypeFromOperand(operand);
|
||||
int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2];
|
||||
return std::max<int>(K / mmaInstrK, 1);
|
||||
}
|
||||
|
||||
// Get number of elements per thread for $a operand.
|
||||
static size_t getANumElemsPerThread(RankedTensorType operand,
|
||||
ArrayRef<unsigned> wpt) {
|
||||
auto shape = operand.getShape();
|
||||
int repM = getNumRepM(operand, shape[0], wpt[0]);
|
||||
int repK = getNumRepK_(operand, shape[1]);
|
||||
return 4 * repM * repK;
|
||||
}
|
||||
|
||||
// Get number of elements per thread for $b operand.
|
||||
static size_t getBNumElemsPerThread(RankedTensorType operand,
|
||||
ArrayRef<unsigned> wpt) {
|
||||
auto shape = operand.getShape();
|
||||
int repK = getNumRepK_(operand, shape[0]);
|
||||
int repN = getNumRepN(operand, shape[1], wpt[1]);
|
||||
return 4 * std::max(repN / 2, 1) * repK;
|
||||
}
|
||||
|
||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||
Value loadA(Value tensor, Value llTensor) const {
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
@@ -2863,9 +2944,6 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
// step2. Format the values to LLVM::Struct to passing to mma codegen.
|
||||
Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
|
||||
|
||||
// TODO[Superjomn]: Replace the convert_layout op with the result once the
|
||||
// DotOperandEncodingAttr is ready.
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -2894,15 +2972,21 @@ struct MMA16816ConversionHelper {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Loading $c from smem(?) to registers, returns a Value.
|
||||
// NOTE Only SplatLike tensor is supported now.
|
||||
Value loadC(Value tensor) const {
|
||||
// Currently, we only support a SplatLike C. For the other cases, e.g., C in
|
||||
// shared layout or blocked layout, we will support them by expanding
|
||||
// convert_layout.
|
||||
auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
|
||||
assert(hc.size() == 4UL && "Only splat-like C is supported now");
|
||||
return hc[0];
|
||||
// Loading $c to registers, returns a Value.
|
||||
Value loadC(Value tensor, Value llTensor) const {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
|
||||
size_t fcSize = 4 * repM * repN;
|
||||
|
||||
assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
|
||||
"Currently, we only support $c with a mma layout.");
|
||||
// Load a normal C tensor with mma layout, that should be a
|
||||
// LLVM::struct with fcSize elements.
|
||||
auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
|
||||
assert(structTy.getBody().size() == fcSize &&
|
||||
"DotOp's $c operand should pass the same number of values as $d in "
|
||||
"mma layout.");
|
||||
return llTensor;
|
||||
}
|
||||
|
||||
// Conduct the Dot conversion.
|
||||
@@ -2934,9 +3018,8 @@ struct MMA16816ConversionHelper {
|
||||
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
|
||||
ValueTable hb = getValuesFromDotOperandLayoutStruct(
|
||||
loadedB, std::max(numRepN / 2, 1), numRepK);
|
||||
|
||||
const int fcSize = 4 * numRepM * numRepN;
|
||||
SmallVector<Value> fc(fcSize, loadedC);
|
||||
auto fc = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
|
||||
loc, loadedC, rewriter);
|
||||
|
||||
auto callMma = [&](unsigned m, unsigned n, unsigned k) {
|
||||
unsigned colsPerThread = numRepN * 2;
|
||||
@@ -2974,6 +3057,11 @@ struct MMA16816ConversionHelper {
|
||||
for (unsigned n = 0; n < numRepN; ++n)
|
||||
callMma(2 * m, n, 2 * k);
|
||||
|
||||
// NOTE, the barrier here is a temporary trick making the gemm with a
|
||||
// k-forloop pass the precision test, or it will fail.
|
||||
// TODO[Superjomn]: Fix with a more general and performance-friendly way.
|
||||
barrier;
|
||||
|
||||
// replace with new packed result
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
|
||||
@@ -3123,9 +3211,6 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||
// operand $b
|
||||
res = mmaHelper.loadB(src, adaptor.src());
|
||||
} else if (dotOperandLayout.getOpIdx() == 2) {
|
||||
// operand $c
|
||||
res = mmaHelper.loadC(src);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
@@ -3163,10 +3248,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
|
||||
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
|
||||
}
|
||||
|
||||
// TODO[Superjomn]: Process C as a mma layout.
|
||||
// Currently, C is simply treated as a Splat Op, and the data layout is not
|
||||
// mattered.
|
||||
loadedC = mmaHelper.loadC(op.c());
|
||||
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
|
||||
|
||||
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
|
||||
adaptor);
|
||||
@@ -3174,6 +3256,26 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
|
||||
|
||||
/// ====================== mma codegen end ============================
|
||||
|
||||
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
|
||||
Type resType, Type elemType,
|
||||
Value constVal,
|
||||
TypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Location loc) {
|
||||
if (layout.getVersion() == 2) {
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy);
|
||||
size_t fcSize = 4 * repM * repN;
|
||||
|
||||
auto structTy = LLVM::LLVMStructType::getLiteral(
|
||||
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
|
||||
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
|
||||
rewriter, structTy);
|
||||
}
|
||||
|
||||
assert(false && "Unsupported mma layout found");
|
||||
}
|
||||
|
||||
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
|
||||
public:
|
||||
using TypeConverter::convertType;
|
||||
@@ -3199,6 +3301,7 @@ public:
|
||||
}
|
||||
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
if (layout &&
|
||||
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
||||
@@ -3207,11 +3310,50 @@ public:
|
||||
getElemsPerThread(layout, type.getShape());
|
||||
SmallVector<Type, 4> types(numElementsPerThread,
|
||||
convertType(type.getElementType()));
|
||||
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto shared_layout =
|
||||
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
|
||||
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
||||
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
auto [repM, repN] = DotOpConversionHelper::getRepMN(type);
|
||||
size_t fcSize = 4 * repM * repN;
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(fcSize, type.getElementType()));
|
||||
}
|
||||
|
||||
llvm::errs()
|
||||
<< "Unexpected mma layout detected in TritonToLLVMTypeConverter";
|
||||
return llvm::None;
|
||||
|
||||
} else if (auto dot_op_layout =
|
||||
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
|
||||
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
Type elemTy = type.getElementType();
|
||||
|
||||
if (dot_op_layout.getOpIdx() == 0) { // $a
|
||||
int elems =
|
||||
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
if (dot_op_layout.getOpIdx() == 1) { // $b
|
||||
int elems =
|
||||
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
}
|
||||
|
||||
llvm::errs() << "Unexpected dot operand layout detected in "
|
||||
"TritonToLLVMTypeConverter";
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
return llvm::None;
|
||||
}
|
||||
};
|
||||
|
@@ -35,6 +35,9 @@ def matmul_no_scf_kernel(
|
||||
[256, 128, 16, 4],
|
||||
[128, 16, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
[128, 128, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[64, 128, 128, 2],
|
||||
])
|
||||
def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
|
||||
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
|
||||
@@ -78,24 +81,39 @@ def matmul_kernel(
|
||||
tl.store(c_ptrs, accumulator)
|
||||
|
||||
# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
|
||||
# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
|
||||
# [128, 256, 128, 4, 128, 256, 32],
|
||||
# # [256, 128, 64, 4, 256, 128, 16],
|
||||
# # [128, 16, 128, 4, 128, 16, 32],
|
||||
# # [32, 128, 256, 4, 32, 128, 64],
|
||||
# ])
|
||||
# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
|
||||
# a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
|
||||
# b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
|
||||
# c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
|
||||
# grid = lambda META: (1, )
|
||||
# matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
# stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
# stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
# stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
# M=a.shape[0], N=b.shape[1], K=a.shape[1],
|
||||
# BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
# num_warps=NUM_WARPS)
|
||||
# golden = torch.matmul(a, b)
|
||||
# torch.set_printoptions(profile="full")
|
||||
# assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
|
||||
# Non-forloop
|
||||
[64, 32, 64, 4, 64, 32, 64],
|
||||
[128, 64, 128, 4, 128, 64, 128],
|
||||
# K-Forloop
|
||||
[64, 32, 128, 4, 64, 32, 64],
|
||||
[128, 16, 128, 4, 128, 16, 32],
|
||||
[32, 16, 128, 4, 32, 16, 32],
|
||||
[32, 64, 128, 4, 32, 64, 32],
|
||||
[32, 128, 256, 4, 32, 128, 64],
|
||||
[64, 128, 64, 4, 64, 128, 32],
|
||||
[128, 128, 64, 4, 128, 128, 32],
|
||||
[64, 64, 128, 4, 64, 64, 32],
|
||||
[128, 128, 128, 4, 128, 128, 32],
|
||||
[128, 128, 256, 4, 128, 128, 64],
|
||||
[128, 256, 128, 4, 128, 256, 32],
|
||||
[256, 128, 64, 4, 256, 128, 16],
|
||||
[128, 64, 128, 4, 128, 64, 32],
|
||||
])
|
||||
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
|
||||
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
|
||||
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
|
||||
grid = lambda META: (1, )
|
||||
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
M=a.shape[0], N=b.shape[1], K=a.shape[1],
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=NUM_WARPS)
|
||||
golden = torch.matmul(a, b)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
|
||||
|
@@ -872,7 +872,9 @@ def make_tritongpu_ir(mod, num_warps):
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
# Get error in backend due to wrong conversion in expanding async-related instruction.
|
||||
# TODO[Superjomn]: Open it when fixed.
|
||||
# pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_coalesce_pass()
|
||||
|
Reference in New Issue
Block a user