add testing
@@ -2603,7 +2603,7 @@ public:

     Value ptr = getPtr(ptrIdx);

-    if (canUseLdmatrix) {
+    if (canUseLdmatrix) { // work with fp16
      int sOffset =
          matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
      PTXBuilder builder;
@@ -2626,12 +2626,13 @@ public:
        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
      };

-      Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+      // The struct should have exactly the same element types.
+      Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];

-      return {extract_val(fp16x2Ty, resV4, getIntAttr(0)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(1)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(2)),
-              extract_val(fp16x2Ty, resV4, getIntAttr(3))};
+      return {extract_val(elemType, resV4, getIntAttr(0)),
+              extract_val(elemType, resV4, getIntAttr(1)),
+              extract_val(elemType, resV4, getIntAttr(2)),
+              extract_val(elemType, resV4, getIntAttr(3))};
    } else if (elemBytes == 4 &&
               needTrans) { // Use lds.32 to load tf32 matrices
      Value ptr2 = getPtr(ptrIdx + 1);
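Note on the hunk above: the ldmatrix result type is no longer hard-coded to a fp16x2 vector; it is read off the returned LLVM struct, so the same path can carry fp16 pairs, tf32 values, or packed int8. A hedged PyTorch sketch (not part of the commit) of why one 32-bit lane admits several element types — the lane is raw bits, and the chosen element type only decides how they are read:

```python
import torch

# One raw 32-bit lane, as ldmatrix would return it; the conversion's chosen
# element type decides how these bits are interpreted.
lane = torch.tensor([0x3C003800], dtype=torch.int32)
pair = lane.view(torch.float16)  # reinterpret the same bits as two fp16 values
print(pair)  # tensor([0.5000, 1.0000], dtype=torch.float16) on little-endian
```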
@@ -2658,9 +2659,9 @@ public:
        elems[3] =
            load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
      }

      return {elems[0], elems[1], elems[2], elems[3]};
-    } else if (elemBytes == 1 && needTrans) {
+    } else if (elemBytes == 1 && needTrans) { // work with int8
      std::array<std::array<Value, 4>, 2> ptrs;
      ptrs[0] = {
          getPtr(ptrIdx),
@@ -2688,17 +2689,18 @@ public:

      Value i8Elems[4][4];
      Type elemTy = type::i8Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
      if (kOrder == 1) {
        Value offset = i32_val(sOffsetElem);

        for (int i = 0; i < 2; ++i)
          for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
+            i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], offset));

        offset = i32_val(sOffsetElem + sOffsetArrElem);
        for (int i = 2; i < 4; ++i)
          for (int j = 0; j < 4; ++j)
-            i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
+            i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i - 2][j], offset));

        for (int m = 0; m < 4; ++m) {
          for (int e = 0; e < 4; ++e)
@@ -2709,14 +2711,14 @@ public:
      } else { // k first
        Value offset = i32_val(sOffsetElem);
        for (int j = 0; j < 4; ++j)
-          i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
+          i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], offset));
        for (int j = 0; j < 4; ++j)
-          i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
+          i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], offset));
        offset = i32_val(sOffsetElem + sOffsetArrElem);
        for (int j = 0; j < 4; ++j)
-          i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
+          i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], offset));
        for (int j = 0; j < 4; ++j)
-          i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
+          i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], offset));

        for (int m = 0; m < 4; ++m) {
          for (int e = 0; e < 4; ++e)
@@ -3501,9 +3503,10 @@ struct MMA16816ConversionHelper {
        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
      };

+      Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
      for (int i = 0; i < 4; ++i)
        fc[m * colsPerThread + 4 * n + i] =
-            extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i));
+            extract_val(elemTy, mmaOut, getIntAttr(i));
    };

    for (int k = 0; k < numRepK; ++k)
@@ -3511,9 +3514,14 @@ struct MMA16816ConversionHelper {
      for (int n = 0; n < numRepN; ++n)
        callMma(2 * m, n, 2 * k);

+    // bitcast to fp32 in bulk
+    for (auto &elem : fc) {
+      elem = bitcast(elem, type::i32Ty(ctx));
+    }
+
    // replace with new packed result
    Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(fc.size(), type::f32Ty(ctx)));
+        ctx, SmallVector<Type>(fc.size(), type::i32Ty(ctx)));
    Value res = getStructFromElements(loc, fc, rewriter, structTy);
    rewriter.replaceOp(op, res);

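Note on the hunk above: despite the comment saying fp32, the new loop reinterprets each fp32 accumulator lane as i32 (`type::i32Ty(ctx)`), matching the `i32Ty` struct element type introduced a few lines below. A bitcast only relabels the bits, so the repacking is lossless; a hedged PyTorch sketch (not part of the commit):

```python
import torch

# A bitcast changes the element type but never the bits, so packing fp32
# accumulator lanes as i32 and viewing them back is an exact round-trip.
acc = torch.tensor([1.0, -2.5, 3.25, 0.0], dtype=torch.float32)
packed = acc.view(torch.int32)                       # fp32 lanes viewed as i32
assert torch.equal(packed.view(torch.float32), acc)  # bits are unchanged
```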
@@ -3607,10 +3615,9 @@ private:

    assert(!elems.empty());

-    Type fp16Ty = type::f16Ty(ctx);
-    Type fp16x2Ty = vec_ty(fp16Ty, 2);
+    Type elemTy = elems[0].getType();
    Type structTy = LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
+        ctx, SmallVector<Type>(elems.size(), elemTy));
    auto result = getStructFromElements(loc, elems, rewriter, structTy);
    return result;
  }
@@ -3634,161 +3641,6 @@ private:
  }
};

-// Helper for FMADot conversion.
-class DotOpFMAConversionHelper {
-public:
-  MmaEncodingAttr mmaLayout;
-  ArrayRef<unsigned> wpt;
-
-  using ValueTable = std::map<std::pair<int, int>, Value>;
-
-  explicit DotOpFMAConversionHelper(MmaEncodingAttr mmaLayout)
-      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
-
-  // Currently, we can tell whether to use FMAdot only from the operand type,
-  // while in the original code, FMADot requires that both the operand and
-  // result of dot should be fp32.
-  // This method should be safe to use in the cases where tensor core is not
-  // appliable.
-  static bool useFMA(TensorType operand) {
-    return operand.getElementType().isF32();
-  }
-
-  Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
-              Value smem, ConversionPatternRewriter &rewriter) const {
-
-    auto *ctx = rewriter.getContext();
-    auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto aShape = tensorTy.getShape();
-    auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-    auto aOrder = aLayout.getOrder();
-
-    bool isARow = aOrder[0] == 1;
-
-    int strideAM = isARow ? aShape[1] : 1;
-    int strideAK = isARow ? 1 : aShape[0];
-    int strideA0 = isARow ? strideAK : strideAM;
-    int strideA1 = isARow ? strideAM : strideAK;
-    int lda = isARow ? strideAM : strideAK;
-    int aPerPhase = aLayout.getPerPhase();
-    int aMaxPhase = aLayout.getMaxPhase();
-    int aNumPtr = 8;
-    int bNumPtr = 8;
-    int aVec = 2;
-
-    Value _0 = i32_val(0);
-    Value _1 = i32_val(1);
-
-    Value mContig = _1;
-    Value nContig = _1;
-
-    Value offA0 = isARow ? _0 : mul(threadId, mContig);
-    Value offA1 = isARow ? mul(threadId, mContig) : _0;
-    SmallVector<Value> aOff(aNumPtr);
-    for (int i = 0; i < aNumPtr; ++i) {
-      aOff[i] =
-          add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
-    }
-
-    Type f32PtrTy = ptr_ty(f32_ty);
-    SmallVector<Value> aPtrs(aNumPtr);
-    for (int i = 0; i < aNumPtr; ++i)
-      aPtrs[i] = gep(f32PtrTy, llTensor, aOff[i]);
-
-    ValueTable has;
-
-    auto aShapePerCTA = getShapePerCTA(aLayout);
-    auto sizePerThread = getSizePerThread(aLayout);
-    int M = isARow ? aShape[0] : aShape[1];
-    int K = isARow ? aShape[1] : aShape[0];
-
-    for (unsigned k = 0; k < K; k++)
-      for (unsigned m = 0; m < M; m += aShapePerCTA[aOrder[1]])
-        for (unsigned mm = 0; mm < sizePerThread[aOrder[1]]; ++mm) {
-          Value pa = gep(f32PtrTy, aPtrs[0],
-                         i32_val((m + mm) * strideAM + k * strideAK));
-          Value va = load(pa);
-          has[{m + mm, k}] = va;
-        }
-
-    SmallVector<Value> values;
-    for (auto &item : has)
-      values.push_back(item.second);
-    Type structTy =
-        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
-
-    return getStructFromElements(loc, values, rewriter, structTy);
-  }
-
-  Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
-              Value smem, ConversionPatternRewriter &rewriter) const {
-
-    auto *ctx = rewriter.getContext();
-    auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto bShape = tensorTy.getShape();
-    auto bLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-    auto bOrder = bLayout.getOrder();
-
-    bool isBRow = bOrder[0] == 1;
-
-    int strideBN = isBRow ? 1 : bShape[0];
-    int strideBK = isBRow ? bShape[1] : 1;
-    int strideB0 = isBRow ? strideBN : strideBK;
-    int strideB1 = isBRow ? strideBK : strideBN;
-    int ldb = isBRow ? strideBK : strideBN;
-    int bPerPhase = bLayout.getPerPhase();
-    int bMaxPhase = bLayout.getMaxPhase();
-    int bNumPtr = 8;
-    int bVec = 4;
-
-    auto bShapePerCTA = getShapePerCTA(bLayout);
-    auto sizePerThread = getSizePerThread(bLayout);
-
-    Value _0 = i32_val(0);
-    Value _1 = i32_val(1);
-
-    Value mContig = _1;
-    Value nContig = _1;
-
-    Value offB0 = isBRow ? mul(threadId, nContig) : _0;
-    Value offB1 = isBRow ? _0 : mul(threadId, nContig);
-    SmallVector<Value> bOff(bNumPtr);
-    for (int i = 0; i < bNumPtr; ++i) {
-      bOff[i] =
-          add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
-    }
-
-    Type f32PtrTy = ptr_ty(f32_ty);
-    SmallVector<Value> bPtrs(bNumPtr);
-    for (int i = 0; i < bNumPtr; ++i)
-      bPtrs[i] = gep(f32PtrTy, llTensor, bOff[i]);
-
-    ValueTable hbs;
-
-    int K = isBRow ? bShape[0] : bShape[1];
-    int N = isBRow ? bShape[1] : bShape[0];
-
-    for (int k = 0; k < K; ++k)
-      for (unsigned n = 0; n < N; n += bShapePerCTA[bOrder[0]])
-        for (unsigned nn = 0; nn < sizePerThread[bOrder[0]]; ++nn) {
-          Value pb = gep(f32PtrTy, bPtrs[0],
-                         i32_val((n + nn) * strideBN + k * strideBK));
-          Value vb = load(pb);
-          hbs[{n + nn, k}] = vb;
-        }
-
-    SmallVector<Value> values;
-    for (auto &item : hbs)
-      values.push_back(item.second);
-    Type structTy =
-        struct_ty(SmallVector<Type>(values.size(), values[0].getType()));
-
-    return getStructFromElements(loc, values, rewriter, structTy);
-  }
-
-  ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
-};
-
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
@@ -3842,15 +3694,6 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
                         adaptor.src(), loc, rewriter);
    }
-  } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
-    DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
-      res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
-    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
-      res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
-                         adaptor.src(), loc, rewriter);
-    }
  } else {
    assert(false && "Unsupported mma layout found");
  }
@@ -4321,6 +4164,8 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
  auto loc = op.getLoc();
  auto threadId = getThreadId(rewriter, loc);

+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
  auto A = op.a();
  auto B = op.b();
  auto C = op.c();
@@ -4400,8 +4245,7 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
  for (int i = 0; i < bNumPtr; ++i)
    bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);

-  // TODO initialize ret with $c.
-  DotOpFMAConversionHelper::ValueTable has, hbs;
+  ValueTable has, hbs;
  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
  SmallVector<Value> ret = cc;

@@ -144,3 +144,50 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO

    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
+
+
+# Precision regression for FMADot is not done yet due to an issue where the optimizer fails to give a blocked layout to the dot op.
+# TODO[Superjomn]: Uncomment this test and finish the precision regression later.
+# @pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+#     [128, 256, 128, 4, 128, 256, 32],
+#     [256, 128, 64, 4, 256, 128, 16],
+#     [128, 64, 128, 4, 128, 64, 32],
+# ])
+# def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+#     @triton.jit
+#     def matmul_kernel(
+#         a_ptr, b_ptr, c_ptr,
+#         stride_am, stride_ak,
+#         stride_bk, stride_bn,
+#         stride_cm, stride_cn,
+#         K: tl.constexpr,
+#         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+#     ):
+#         offs_m = tl.arange(0, BLOCK_SIZE_M)
+#         offs_n = tl.arange(0, BLOCK_SIZE_N)
+#         offs_k = tl.arange(0, BLOCK_SIZE_K)
+#         a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+#         b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+#         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+#         for k in range(0, K, BLOCK_SIZE_K):
+#             a = tl.load(a_ptrs)
+#             b = tl.load(b_ptrs)
+#             accumulator += tl.dot(a, b, allow_tf32=True)
+#             a_ptrs += BLOCK_SIZE_K * stride_ak
+#             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+#         c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+#         tl.store(c_ptrs, accumulator)
+
+#     a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+#     b = torch.randn((K, N), device='cuda', dtype=torch.float)
+#     c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+#     grid = lambda META: (1, )
+#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+#                         stride_am=a.stride(0), stride_ak=a.stride(1),
+#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
+#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
+#                         K=a.shape[1], BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N,
+#                         BLOCK_SIZE_K=block_K, num_warps=num_warps)
+#     golden = torch.matmul(a, b)
+#     torch.testing.assert_close(c, golden)
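Until the disabled test above is restored, a minimal CPU-only stand-in for its golden check is sketched below (hedged, not part of the commit): it emulates the kernel's blocked K-loop and compares against `torch.matmul`, mirroring the `[128, 256, 128, 4, 128, 256, 32]` parametrization.

```python
import torch
from torch.testing import assert_close

# Emulate the kernel's blocked accumulation over K, then compare against the
# torch.matmul golden result within fp32 FMA-level tolerances.
M, N, K, BLOCK_K = 128, 256, 128, 32
a = torch.randn((M, K), dtype=torch.float32)
b = torch.randn((K, N), dtype=torch.float32)
acc = torch.zeros((M, N), dtype=torch.float32)
for k in range(0, K, BLOCK_K):  # same K traversal as the commented-out kernel
    acc += a[:, k:k + BLOCK_K] @ b[k:k + BLOCK_K, :]
assert_close(acc, torch.matmul(a, b), rtol=1e-4, atol=1e-4)
```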