[Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)
TODO:
- Add more cases.
- Currently, we just set vec to 4 to make the basic cases pass.

Issue:
- The vec in the shared layout differs from the master branch.
- When vec = 1, it hits a CUDA misalignment error; this does not work on the master branch either.
- When vec is set to the value identical to the master branch, the MMA works.
.github/workflows/integration-tests.yml
@@ -17,7 +17,7 @@ jobs:
       id: set-matrix
       run: |
         if [ x"${{ github.repository }}" == x"openai/triton" ]; then
-          echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
+          echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
         else
           echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
         fi
@@ -79,11 +79,18 @@ jobs:
         lit -v "$LIT_TEST_DIR"

     - name: Run python tests
-      if: ${{matrix.runner[0] == 'self-hosted'}}
+      if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
       run: |
        cd python/tests
        pytest

+    # TODO[Superjomn] Enable all the tests on V100 if available
+    - name: Run python tests on V100
+      if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
+      run: |
+        cd python/tests
+        pytest test_gemm.py::test_gemm_no_scf_for_mmav1
+
     - name: Run CXX unittests
       run: |
        cd python/
@@ -94,13 +94,17 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
     // ---- begin version 1 ----
     if (version == 1) {
       bool is_row = order[0] != 0;
-      bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) :
-                                  !is_row && (shape[order[0]] <= 16);
+      bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
+                                  is_row && (shape[order[0]] <= 16);
+      // TODO[Superjomn]: Support the case when is_vec4=false later
+      // Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
+      is_vec4 = true;
       int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
                                    ((is_row && !is_vec4) ? 2 : 1);
       int rep = 2 * pack_size;
       int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
-      return $_get(context, 2 * rep, perPhase, maxPhase, order);
+      int vec = 2 * rep;
+      return $_get(context, vec, perPhase, maxPhase, order);
     }

     // ---- begin version 2 ----
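To see why this hunk pins vec to 4 (the workaround named in the commit message), here is a minimal standalone trace of the hunk's arithmetic; the is_row value is an assumption for illustration:

    #include <cassert>

    int main() {
      // Assumptions for illustration: operand A (opIdx == 0), column-major.
      bool is_row = false;
      bool is_vec4 = true; // forced by the hunk above, whatever the shape says
      int pack_size = (is_row || is_vec4) ? 1 : 2; // opIdx == 0 branch
      int rep = 2 * pack_size;                     // rep = 2
      int vec = 2 * rep;                           // vec = 4
      assert(pack_size == 1 && rep == 2 && vec == 4);
      return 0;
    }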
@@ -39,15 +39,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;

-// Forward declaration necessary functions locates in TritonGPUToLLVM.cpp .
-llvm::SmallVector<mlir::Value>
-getElementsFromStruct(mlir::Location loc, mlir::Value llvmStruct,
-                      mlir::ConversionPatternRewriter &rewriter);
-
-mlir::LLVM::SharedMemoryObject
-getSharedMemoryObjectFromStruct(mlir::Location loc, mlir::Value llvmStruct,
-                                mlir::ConversionPatternRewriter &rewriter);
-
 // Helper for conversion of DotOp with mma<version=1>, that is sm<80
 struct DotOpMmaV1ConversionHelper {
   MmaEncodingAttr mmaLayout;
@@ -710,17 +701,13 @@ public:
     if (kOrder == 1) {
       elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
       elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
-      elems[2] =
-          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] =
-          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
+      elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     } else {
       elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
       elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
-      elems[1] =
-          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] =
-          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
+      elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     }
     return {elems[0], elems[1], elems[2], elems[3]};
@@ -952,7 +939,6 @@ struct MMA16816ConversionHelper {
   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();

     SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
                                aTensorTy.getShape().end());
@@ -973,12 +959,13 @@ struct MMA16816ConversionHelper {
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       Value warpM = getWarpM(shape[0]);
       // load from smem
-      int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
-      loadFn = getLoadMatrixFn(
-          tensor, smemObj, mmaLayout, wpt /*wpt*/,
-          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
-          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
-          true /*isA*/);
+      int wpt =
+          std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
+      loadFn =
+          getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/,
+                          {mmaInstrM, mmaInstrK} /*instrShape*/,
+                          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/,
+                          ha /*vals*/, true /*isA*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
       // load from registers, used in gemm fuse
       // TODO(Superjomn) Port the logic.
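The std::min clamp above matters for the small shapes this PR enables: wpt cannot exceed the number of mma tiles along M, or warps would index past the tensor. A worked example with assumed values (the warp count and tile extent are illustrative, not read from the code):

    #include <algorithm>
    #include <cstdio>

    int main() {
      int warpsPerCTA_M = 4; // assumed launch configuration
      int shapeM = 16;       // tensor extent along M, e.g. the [16, 16, 16] test
      int matShapeM = 8;     // assumed per-instruction tile extent along M
      // Without the clamp, wpt would be 4 and only 2 tiles exist along M.
      int wpt = std::min(warpsPerCTA_M, shapeM / matShapeM);
      printf("wpt = %d\n", wpt); // prints: wpt = 2
      return 0;
    }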
@@ -1000,7 +987,6 @@ struct MMA16816ConversionHelper {
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();

     SmallVector<int64_t> shape(tensorTy.getShape().begin(),
                                tensorTy.getShape().end());
@@ -1017,12 +1003,13 @@ struct MMA16816ConversionHelper {
     int numRepN = getNumRepN(tensorTy, shape[1]);

     Value warpN = getWarpN(shape[1]);
-    int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
-    auto loadFn = getLoadMatrixFn(
-        tensor, smemObj, mmaLayout, wpt /*wpt*/,
-        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
-        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
-        false /*isA*/);
+    int wpt =
+        std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
+    auto loadFn =
+        getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/,
+                        {mmaInstrK, mmaInstrN} /*instrShape*/,
+                        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/,
+                        hb /*vals*/, false /*isA*/);

     for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
       for (int k = 0; k < numRepK; ++k)
@@ -1167,6 +1154,7 @@ private:
     SmallVector<Value> ptrs(numPtrs);

+    Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
     Type smemPtrTy = helper.getShemPtrTy();
     for (int i = 0; i < numPtrs; ++i) {
       ptrs[i] =
@@ -1292,7 +1280,6 @@ struct DotOpFMAConversionHelper {
   auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
   auto shapePerCTA = getShapePerCTA(blockedLayout);
-  auto sizePerThread = getSizePerThread(blockedLayout);
   auto order = blockedLayout.getOrder();

   // TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it
   // if not.
@@ -1342,17 +1329,15 @@ Value DotOpMmaV1ConversionHelper::loadA(
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());

-  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-  bool transA = false;
-  if (transA) {
-    std::swap(shape[0], shape[1]);
-    std::swap(order[0], order[1]);
-  }
-
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+
   bool isARow = order[0] != 0;
   bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
+  // TODO[Superjomn]: Support the case when isAVec4=false later
+  // Currently, we only support ld.v2, for the mma layout varies with different
+  // ld vector width.
+  isAVec4 = true;
   int packSize0 = (isARow || isAVec4) ? 1 : 2;

   SmallVector<int> fpw({2, 2, 1});
@@ -1362,6 +1347,16 @@ Value DotOpMmaV1ConversionHelper::loadA(
   SmallVector<int> rep({repM, 0, repK}); // pad N with 0
   SmallVector<int> spw({spwM, 0, 1});    // pad N with 0

+  auto [offsetAM, offsetAK, _0, _1] =
+      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
+  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
+  bool transA = false;
+  if (transA) {
+    std::swap(shape[0], shape[1]);
+    std::swap(offsetAM, offsetAK);
+    std::swap(order[0], order[1]);
+  }
+
   int vecA = sharedLayout.getVec();

   auto strides = smemObj.strides;
@@ -1373,9 +1368,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
   int strideRepM = wpt[0] * fpw[0] * 8;
   int strideRepK = 1;

-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-
   // swizzling
   int perPhaseA = sharedLayout.getPerPhase();
   int maxPhaseA = sharedLayout.getMaxPhase();
@@ -1398,19 +1390,14 @@ Value DotOpMmaV1ConversionHelper::loadA(
   }

   Type f16x2Ty = vec_ty(f16_ty, 2);
   // One thread get 8 elements as result
-  Type retTy =
-      LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx)));
-
-  // prepare arguments
   SmallVector<Value> ptrA(numPtrA);

   std::map<std::pair<int, int>, std::pair<Value, Value>> has;
-  auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   for (int i = 0; i < numPtrA; i++)
-    ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
+    ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);

   auto instrShape = getMmaInstrShape();
   unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);

   Type f16PtrTy = ptr_ty(f16_ty);
@@ -1420,7 +1407,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
   };
   auto loadA = [&](int m, int k) {
     int offidx = (isARow ? k / 4 : m) % numPtrA;
-    Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
+    Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);

     int stepAM = isARow ? m : m / numPtrA * numPtrA;
     int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
@@ -1446,12 +1433,10 @@ Value DotOpMmaV1ConversionHelper::loadA(

   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
       if (!has.count({m, k}))
         loadA(m, k);

   SmallVector<Value> elems;
   elems.reserve(has.size() * 2);
-  auto vecTy = vec_ty(f16_ty, 2);
   for (auto item : has) { // has is a map, the key should be ordered.
     elems.push_back(item.second.first);
     elems.push_back(item.second.second);
@@ -1466,7 +1451,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
     Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
     ConversionPatternRewriter &rewriter) const {
   // smem
-  Value smem = smemObj.base;
   auto strides = smemObj.strides;

   auto *ctx = rewriter.getContext();
@@ -1478,21 +1462,20 @@ Value DotOpMmaV1ConversionHelper::loadB(
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());

-  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
-  bool transB = false;
-
-  if (transB) {
-    std::swap(order[0], order[1]);
-    std::swap(shape[0], shape[1]);
-  }
+  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isBRow = order[0] != 0;
   bool isBVec4 = isBRow && shape[order[0]] <= 16;
+  // TODO[Superjomn]: Support the case when isBVec4=false later
+  // Currently, we only support ld.v2, for the mma layout varies with different
+  // ld vector width.
+  isBVec4 = true;
   int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
   SmallVector<int> fpw({2, 2, 1});
   SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
   SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int vecB = sharedLayout.getVec();

   Value strideBN = isBRow ? i32_val(1) : strides[1];
   Value strideBK = isBRow ? strides[0] : i32_val(1);
   Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1500,24 +1483,29 @@ Value DotOpMmaV1ConversionHelper::loadB(
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;

+  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
+  bool transB = false;
+
+  auto [_0, _1, offsetBN, offsetBK] =
+      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  if (transB) {
+    std::swap(order[0], order[1]);
+    std::swap(shape[0], shape[1]);
+    std::swap(offsetBK, offsetBN);
+  }
+
   // swizzling
-  int perPhaseA = sharedLayout.getPerPhase();
-  int maxPhaseA = sharedLayout.getMaxPhase();
+  int perPhaseB = sharedLayout.getPerPhase();
+  int maxPhaseB = sharedLayout.getMaxPhase();
   int stepB0 = isBRow ? strideRepN : strideRepK;
   int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
   int NK = shape[0];

-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
-  if (transB)
-    std::swap(offsetBK, offsetBN);
-
   Value offB0 = isBRow ? offsetBN : offsetBK;
   Value offB1 = isBRow ? offsetBK : offsetBN;
   Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);

+  offB0 = add(offB0, cSwizzleOffset);
   SmallVector<Value> offB(numPtrB);
   for (int i = 0; i < numPtrB; ++i) {
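The phaseB line above is the swizzle pattern in scalar form: phase = (row / perPhase) % maxPhase. A minimal sketch with assumed layout parameters, showing how consecutive rows cycle through phases:

    #include <cstdio>

    int main() {
      int perPhaseB = 2, maxPhaseB = 4; // assumed shared-layout parameters
      for (int offB1 = 0; offB1 < 16; ++offB1) {
        // Mirrors urem(udiv(offB1, perPhaseB), maxPhaseB) from the hunk above.
        int phaseB = (offB1 / perPhaseB) % maxPhaseB;
        printf("row %2d -> phase %d\n", offB1, phaseB);
      }
      return 0;
    }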
@@ -1549,6 +1537,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
       Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
                          mul(i32_val(stepBK), strideBK));
       Value pb = gep(f16PtrTy, thePtrB, offset);
+
       Value hb =
           load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
       // record lds that needs to be moved
@@ -1651,9 +1640,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   SmallVector<Value> elems =
       getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);

-  for (int k = 0, offset = 0, i = 0; k < NK && offset < elems.size();
-       k += 4, i++, offset += 2) {
-    rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+  int offset = 0;
+  for (int i = 0; offset < elems.size(); ++i) {
+    for (int k = 0; k < NK; k += 4) {
+      rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+      offset += 2;
+    }
   }

   return rcds;
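The rewritten loop walks every (i, k) pair instead of advancing i and k in lockstep, so a struct holding more than one row of pairs is now fully unpacked. A self-contained sketch of the new indexing (NK and the element count are assumptions):

    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    int main() {
      int NK = 8;                 // assumed K extent; k steps by 4 -> k in {0, 4}
      std::vector<int> elems = {0, 1, 2, 3, 4, 5, 6, 7}; // four (lo, hi) pairs
      std::map<std::pair<int, int>, std::pair<int, int>> rcds;
      int offset = 0;
      for (int i = 0; offset < (int)elems.size(); ++i)
        for (int k = 0; k < NK; k += 4) {
          rcds[{i, k}] = {elems[offset], elems[offset + 1]};
          offset += 2;
        }
      // Visits (i=0,k=0), (i=0,k=4), (i=1,k=0), (i=1,k=4); the old lockstep
      // loop would have stopped after (0,0) and (1,4), dropping half the pairs.
      for (auto &kv : rcds)
        printf("(i=%d, k=%d) -> (%d, %d)\n", kv.first.first, kv.first.second,
               kv.second.first, kv.second.second);
      return 0;
    }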
@@ -1675,9 +1667,7 @@ Value DotOpFMAConversionHelper::loadA(
   int strideAK = isARow ? 1 : aShape[0];
   int strideA0 = isARow ? strideAK : strideAM;
   int strideA1 = isARow ? strideAM : strideAK;
-  int lda = isARow ? strideAM : strideAK;
   int aNumPtr = 8;
-  int bNumPtr = 8;
   int NK = aShape[1];

   auto shapePerCTA = getShapePerCTA(dLayout);
@@ -1686,13 +1676,11 @@ Value DotOpFMAConversionHelper::loadA(
   Value _0 = i32_val(0);

   Value mContig = i32_val(sizePerThread[order[1]]);
-  Value nContig = i32_val(sizePerThread[order[0]]);

   // threadId in blocked layout
   auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);

   Value threadIdM = threadIds[0];
-  Value threadIdN = threadIds[1];

   Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
   Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
@@ -1745,7 +1733,6 @@ Value DotOpFMAConversionHelper::loadB(
   int strideBK = isBRow ? bShape[1] : 1;
   int strideB0 = isBRow ? strideBN : strideBK;
   int strideB1 = isBRow ? strideBK : strideBN;
-  int ldb = isBRow ? strideBK : strideBN;
   int bNumPtr = 8;
   int NK = bShape[0];
@@ -1754,7 +1741,6 @@ Value DotOpFMAConversionHelper::loadB(
   Value _0 = i32_val(0);

-  Value mContig = i32_val(sizePerThread[order[1]]);
   Value nContig = i32_val(sizePerThread[order[0]]);

   // threadId in blocked layout
@@ -62,12 +62,11 @@ namespace LLVM {
 static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }

 // A helper function for using printf in LLVM conversion.
-void llPrintf(StringRef msg, ValueRange args,
-              ConversionPatternRewriter &rewriter);
+void vprintf(StringRef msg, ValueRange args,
+             ConversionPatternRewriter &rewriter);

-// Helper function
-#define tid_val() getThreadId(rewriter, loc)
-#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
+void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
+                   std::string elem_repr, ConversionPatternRewriter &builder);

 } // namespace LLVM
 } // namespace mlir
@@ -3537,8 +3536,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   SmallVector<Value> resVals(resSize);

   auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
-    auto ha = has[{m, k}];
-    auto hb = hbs[{n, k}];
+    auto ha = has.at({m, k});
+    auto hb = hbs.at({n, k});
     std::vector<size_t> idx{{
         (m * 2 + 0) + (n * 4 + 0) * numM, // row0
         (m * 2 + 0) + (n * 4 + 1) * numM,
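Switching from operator[] to .at() here is more than style: on std::map, operator[] default-constructs a value for a missing key, so a bad (m, k) would silently yield null Values, while .at() fails loudly. A minimal demonstration:

    #include <cassert>
    #include <map>
    #include <stdexcept>
    #include <utility>

    int main() {
      std::map<std::pair<int, int>, int> has;
      int x = has[{0, 0}]; // silently inserts {0, 0} -> 0; a bug goes unnoticed
      assert(has.size() == 1 && x == 0);
      bool threw = false;
      try {
        (void)has.at({1, 1}); // missing key: throws instead of inserting
      } catch (const std::out_of_range &) {
        threw = true;        // the error surfaces immediately
      }
      assert(threw);
      return 0;
    }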
@@ -3554,13 +3553,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
     auto *resOprs = builder.newListOperand(8, "=f");
     auto *AOprs = builder.newListOperand({
-        {ha.first, "f"},
-        {ha.second, "f"},
+        {ha.first, "r"},
+        {ha.second, "r"},
     });

     auto *BOprs = builder.newListOperand({
-        {hb.first, "f"},
-        {hb.second, "f"},
+        {hb.first, "r"},
+        {hb.second, "r"},
     });
     auto *COprs = builder.newListOperand();
     for (int i = 0; i < 8; ++i)
@@ -4806,11 +4805,23 @@ namespace mlir {

 namespace LLVM {

-void llPrintf(StringRef msg, ValueRange args,
-              ConversionPatternRewriter &rewriter) {
+void vprintf(StringRef msg, ValueRange args,
+             ConversionPatternRewriter &rewriter) {
   PrintfOpConversion::llPrintf(msg, args, rewriter);
 }

+void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
+                   std::string elem_repr, ConversionPatternRewriter &builder) {
+  std::string fmt = info + " t-%d ";
+  std::vector<Value> new_arr({thread});
+  for (int i = 0; i < arr.size(); ++i) {
+    fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
+    new_arr.push_back(arr[i]);
+  }
+
+  vprintf(fmt, new_arr, builder);
+}
+
 } // namespace LLVM

 TritonLLVMConversionTarget::TritonLLVMConversionTarget(
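vprintf_array only assembles a format string of the shape "<info> t-%d <elem>, <elem>, ..." and forwards it, with the thread id prepended, to vprintf. The string-building logic in isolation (the MLIR Value plumbing is stubbed out with plain ints):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      std::vector<int> arr = {10, 20, 30}; // stands in for ArrayRef<Value>
      std::string info = "offA", elem_repr = "%d";
      std::string fmt = info + " t-%d ";
      for (size_t i = 0; i < arr.size(); ++i)
        fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
      printf("%s\n", fmt.c_str()); // prints: offA t-%d %d, %d, %d
      return 0;
    }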
@@ -111,6 +111,8 @@
   LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),          \
                             __VA_ARGS__)

+#define tid_val() getThreadId(rewriter, loc)
+
 namespace mlir {
 namespace LLVM {
 using namespace mlir::triton;
@@ -756,6 +756,7 @@ public:
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    int version = computeCapabilityToMMAVersion(computeCapability);

     auto newRetType = RankedTensorType::get(
         retShape, oldRetType.getElementType(),
         triton::gpu::MmaEncodingAttr::get(
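computeCapabilityToMMAVersion itself is not shown in this diff; presumably it maps the SM architecture onto the mma flavor this patch dispatches on (v1 for pre-Ampere parts such as the V100 this PR targets, v2 from sm80 up). A hedged sketch of that assumed mapping, not the real implementation:

    // Assumption for illustration only; the actual function lives elsewhere in
    // the tree and may handle more cases (e.g. chips with no tensor cores).
    int computeCapabilityToMMAVersionSketch(int computeCapability) {
      if (computeCapability < 80)
        return 1; // mma v1 path, e.g. sm70 (V100)
      return 2;   // mma v2 path (Ampere-style mma.sync)
    }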
@@ -1383,6 +1383,11 @@ void init_triton_translation(py::module &m) {
           llvm::SMDiagnostic error;
           std::unique_ptr<llvm::Module> module =
               llvm::parseIR(buffer->getMemBufferRef(), error, context);
+          if (!module)
+            llvm::report_fatal_error(
+                "failed to parse IR: " + error.getMessage() +
+                "lineno: " + std::to_string(error.getLineNo()));
+
           // translate module to PTX
           auto ptxCode =
               triton::translateLLVMIRToPTX(*module, capability, version);
@@ -292,3 +292,21 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
     else:
         torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
+
+
+# NOTE this is useful only on Volta GPU.
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [16, 16, 16],
+        [16, 16, 32],
+        [32, 16, 16],
+        [32, 32, 32],
+        [128, 16, 16],
+    ]
+    for num_warps in [1]
+    for trans_a in [False]
+    for trans_b in [False]
+])
+def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)
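The new test reuses test_gemm_no_scf unchanged, with shapes kept small enough for the forced-vec4 shared layout above; it is exactly the case the workflow change at the top runs on the V100 runner via pytest test_gemm.py::test_gemm_no_scf_for_mmav1.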