[Triton-MLIR][BACKEND] Make mmav1 work on basic cases (#944)
TODO:
- Add more cases
- Currently, we just set vec to 4 to make the basic cases pass

Issue:
- The vec in the shared layout is different from the master branch:
  - with vec=1 it hits a CUDA misalignment error (it does not work on the master branch either),
  - with vec set to the value identical to the master branch, the MMA works.
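For reference, a minimal Python sketch (function name and arguments are illustrative only) of how the version-1 shared-layout parameters derived in the diff below land at vec=4 once is_vec4 is forced on:

```python
def mmav1_shared_vec(op_idx: int, is_row: bool, inner_size: int) -> int:
    # Mirrors the version==1 branch changed in this commit.
    is_vec4 = ((not is_row) if op_idx == 0 else is_row) and inner_size <= 16
    # The commit forces is_vec4 because only ld.v2 is supported so far.
    is_vec4 = True
    if op_idx == 0:  # operand $a
        pack_size = 1 if (is_row or is_vec4) else 2
    else:            # operand $b
        pack_size = 2 if (is_row and not is_vec4) else 1
    rep = 2 * pack_size
    return 2 * rep   # the shared-layout vec

# With is_vec4 forced on, pack_size == 1 and rep == 2 for both operands,
# so vec == 4 regardless of shape or order:
assert mmav1_shared_vec(0, True, 16) == 4
assert mmav1_shared_vec(1, False, 32) == 4
```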
.github/workflows/integration-tests.yml
@@ -17,7 +17,7 @@ jobs:
       id: set-matrix
       run: |
         if [ x"${{ github.repository }}" == x"openai/triton" ]; then
-          echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
+          echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
         else
           echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
         fi
@@ -79,11 +79,18 @@ jobs:
         lit -v "$LIT_TEST_DIR"

     - name: Run python tests
-      if: ${{matrix.runner[0] == 'self-hosted'}}
+      if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
       run: |
         cd python/tests
         pytest

+    # TODO[Superjomn] Enable all the tests on V100 if available
+    - name: Run python tests on V100
+      if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
+      run: |
+        cd python/tests
+        pytest test_gemm.py::test_gemm_no_scf_for_mmav1
+
     - name: Run CXX unittests
       run: |
         cd python/
@@ -94,13 +94,17 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
       // ---- begin version 1 ----
       if (version == 1) {
         bool is_row = order[0] != 0;
-        bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) :
-                                    !is_row && (shape[order[0]] <= 16);
+        bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
+                                    is_row && (shape[order[0]] <= 16);
+        // TODO[Superjomn]: Support the case when is_vec4=false later
+        // Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
+        is_vec4 = true;
         int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
                                      ((is_row && !is_vec4) ? 2 : 1);
         int rep = 2 * pack_size;
         int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
-        return $_get(context, 2 * rep, perPhase, maxPhase, order);
+        int vec = 2 * rep;
+        return $_get(context, vec, perPhase, maxPhase, order);
       }

       // ---- begin version 2 ----
@@ -39,15 +39,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::MmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;

-// Forward declaration necessary functions locates in TritonGPUToLLVM.cpp .
-llvm::SmallVector<mlir::Value>
-getElementsFromStruct(mlir::Location loc, mlir::Value llvmStruct,
-                      mlir::ConversionPatternRewriter &rewriter);
-
-mlir::LLVM::SharedMemoryObject
-getSharedMemoryObjectFromStruct(mlir::Location loc, mlir::Value llvmStruct,
-                                mlir::ConversionPatternRewriter &rewriter);
-
 // Helper for conversion of DotOp with mma<version=1>, that is sm<80
 struct DotOpMmaV1ConversionHelper {
   MmaEncodingAttr mmaLayout;
@@ -710,17 +701,13 @@ public:
     if (kOrder == 1) {
       elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
      elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
-      elems[2] =
-          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] =
-          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
+      elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     } else {
       elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
       elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
-      elems[1] =
-          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
-      elems[3] =
-          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
+      elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     }
     return {elems[0], elems[1], elems[2], elems[3]};

@@ -952,7 +939,6 @@ struct MMA16816ConversionHelper {
   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();

     SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
                                aTensorTy.getShape().end());
@@ -973,12 +959,13 @@ struct MMA16816ConversionHelper {
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       Value warpM = getWarpM(shape[0]);
       // load from smem
-      int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
-      loadFn = getLoadMatrixFn(
-          tensor, smemObj, mmaLayout, wpt /*wpt*/,
-          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
-          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
-          true /*isA*/);
+      int wpt =
+          std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
+      loadFn =
+          getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/,
+                          {mmaInstrM, mmaInstrK} /*instrShape*/,
+                          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/,
+                          ha /*vals*/, true /*isA*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
       // load from registers, used in gemm fuse
       // TODO(Superjomn) Port the logic.
@@ -1000,7 +987,6 @@ struct MMA16816ConversionHelper {
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();

     SmallVector<int64_t> shape(tensorTy.getShape().begin(),
                                tensorTy.getShape().end());
@@ -1017,12 +1003,13 @@ struct MMA16816ConversionHelper {
     int numRepN = getNumRepN(tensorTy, shape[1]);

     Value warpN = getWarpN(shape[1]);
-    int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
-    auto loadFn = getLoadMatrixFn(
-        tensor, smemObj, mmaLayout, wpt /*wpt*/,
-        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
-        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
-        false /*isA*/);
+    int wpt =
+        std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
+    auto loadFn =
+        getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/,
+                        {mmaInstrK, mmaInstrN} /*instrShape*/,
+                        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/,
+                        hb /*vals*/, false /*isA*/);

     for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
       for (int k = 0; k < numRepK; ++k)
@@ -1167,6 +1154,7 @@ private:
   SmallVector<Value> ptrs(numPtrs);

   Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+
   Type smemPtrTy = helper.getShemPtrTy();
   for (int i = 0; i < numPtrs; ++i) {
     ptrs[i] =
@@ -1292,7 +1280,6 @@ struct DotOpFMAConversionHelper {
     auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
     auto shapePerCTA = getShapePerCTA(blockedLayout);
     auto sizePerThread = getSizePerThread(blockedLayout);
-    auto order = blockedLayout.getOrder();

     // TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it
     // if not.
@@ -1342,17 +1329,15 @@ Value DotOpMmaV1ConversionHelper::loadA(
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());

-  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-  bool transA = false;
-  if (transA) {
-    std::swap(shape[0], shape[1]);
-    std::swap(order[0], order[1]);
-  }
-
   Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isARow = order[0] != 0;
   bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
+  // TODO[Superjomn]: Support the case when isAVec4=false later
+  // Currently, we only support ld.v2, for the mma layout varies with different
+  // ld vector width.
+  isAVec4 = true;
   int packSize0 = (isARow || isAVec4) ? 1 : 2;

   SmallVector<int> fpw({2, 2, 1});
@@ -1362,6 +1347,16 @@ Value DotOpMmaV1ConversionHelper::loadA(
   SmallVector<int> rep({repM, 0, repK}); // pad N with 0
   SmallVector<int> spw({spwM, 0, 1});    // pad N with 0

+  auto [offsetAM, offsetAK, _0, _1] =
+      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
+  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
+  bool transA = false;
+  if (transA) {
+    std::swap(shape[0], shape[1]);
+    std::swap(offsetAM, offsetAK);
+    std::swap(order[0], order[1]);
+  }
+
   int vecA = sharedLayout.getVec();

   auto strides = smemObj.strides;
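A hypothetical Python paraphrase of the reordering above, for illustration only (names are not the C++ API): the thread's (m, k) offsets are now computed for the untransposed layout first, so a future transA path can swap offsets, shape, and order together in one place.

```python
def apply_trans_a(offset_am, offset_ak, shape, order, trans_a=False):
    # Swap everything consistently when the operand is transposed.
    if trans_a:
        shape = [shape[1], shape[0]]
        offset_am, offset_ak = offset_ak, offset_am
        order = [order[1], order[0]]
    return offset_am, offset_ak, shape, order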
@@ -1373,9 +1368,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
   int strideRepM = wpt[0] * fpw[0] * 8;
   int strideRepK = 1;

-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-
   // swizzling
   int perPhaseA = sharedLayout.getPerPhase();
   int maxPhaseA = sharedLayout.getMaxPhase();
@@ -1398,19 +1390,14 @@ Value DotOpMmaV1ConversionHelper::loadA(
   }

   Type f16x2Ty = vec_ty(f16_ty, 2);
-  // One thread get 8 elements as result
-  Type retTy =
-      LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx)));

   // prepare arguments
   SmallVector<Value> ptrA(numPtrA);

   std::map<std::pair<int, int>, std::pair<Value, Value>> has;
-  auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   for (int i = 0; i < numPtrA; i++)
-    ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
+    ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);

-  auto instrShape = getMmaInstrShape();
   unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);

   Type f16PtrTy = ptr_ty(f16_ty);
@@ -1420,7 +1407,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
   };
   auto loadA = [&](int m, int k) {
     int offidx = (isARow ? k / 4 : m) % numPtrA;
-    Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
+    Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);

     int stepAM = isARow ? m : m / numPtrA * numPtrA;
     int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
@@ -1446,12 +1433,10 @@ Value DotOpMmaV1ConversionHelper::loadA(

   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
-      if (!has.count({m, k}))
-        loadA(m, k);
+      loadA(m, k);

   SmallVector<Value> elems;
   elems.reserve(has.size() * 2);
-  auto vecTy = vec_ty(f16_ty, 2);
   for (auto item : has) { // has is a map, the key should be ordered.
     elems.push_back(item.second.first);
     elems.push_back(item.second.second);
@@ -1466,7 +1451,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
     Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
     ConversionPatternRewriter &rewriter) const {
   // smem
-  Value smem = smemObj.base;
   auto strides = smemObj.strides;

   auto *ctx = rewriter.getContext();
@@ -1478,21 +1462,20 @@ Value DotOpMmaV1ConversionHelper::loadB(
   SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                               sharedLayout.getOrder().end());

-  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
-  bool transB = false;
-  if (transB) {
-    std::swap(order[0], order[1]);
-    std::swap(shape[0], shape[1]);
-  }
+  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isBRow = order[0] != 0;
   bool isBVec4 = isBRow && shape[order[0]] <= 16;
+  // TODO[Superjomn]: Support the case when isBVec4=false later
+  // Currently, we only support ld.v2, for the mma layout varies with different
+  // ld vector width.
+  isBVec4 = true;
   int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
   SmallVector<int> fpw({2, 2, 1});
   SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
   SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
   int vecB = sharedLayout.getVec();

   Value strideBN = isBRow ? i32_val(1) : strides[1];
   Value strideBK = isBRow ? strides[0] : i32_val(1);
   Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1500,24 +1483,29 @@ Value DotOpMmaV1ConversionHelper::loadB(
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;

+  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
+  bool transB = false;
+
+  auto [_0, _1, offsetBN, offsetBK] =
+      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  if (transB) {
+    std::swap(order[0], order[1]);
+    std::swap(shape[0], shape[1]);
+    std::swap(offsetBK, offsetBN);
+  }
+
   // swizzling
-  int perPhaseA = sharedLayout.getPerPhase();
-  int maxPhaseA = sharedLayout.getMaxPhase();
   int perPhaseB = sharedLayout.getPerPhase();
   int maxPhaseB = sharedLayout.getMaxPhase();
   int stepB0 = isBRow ? strideRepN : strideRepK;
   int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
   int NK = shape[0];

-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
-  if (transB)
-    std::swap(offsetBK, offsetBN);
-
   Value offB0 = isBRow ? offsetBN : offsetBK;
   Value offB1 = isBRow ? offsetBK : offsetBN;
   Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
   Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);

   offB0 = add(offB0, cSwizzleOffset);
   SmallVector<Value> offB(numPtrB);
   for (int i = 0; i < numPtrB; ++i) {
@@ -1549,6 +1537,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
     Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
                        mul(i32_val(stepBK), strideBK));
     Value pb = gep(f16PtrTy, thePtrB, offset);
+
     Value hb =
         load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
     // record lds that needs to be moved
@@ -1651,9 +1640,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   SmallVector<Value> elems =
       getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);

-  for (int k = 0, offset = 0, i = 0; k < NK && offset < elems.size();
-       k += 4, i++, offset += 2) {
-    rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+  int offset = 0;
+  for (int i = 0; offset < elems.size(); ++i) {
+    for (int k = 0; k < NK; k += 4) {
+      rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+      offset += 2;
+    }
   }

   return rcds;
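A Python paraphrase of the rewritten indexing above, for illustration (names are hypothetical): every (i, k) key takes one consecutive pair from the flattened struct, and the k loop now covers the whole K range for each i instead of advancing i and k in lockstep.

```python
def extract_loaded_operand(elems, NK):
    # Map consecutive pairs of loaded values to (rep, k) keys.
    rcds = {}
    offset = 0
    i = 0
    while offset < len(elems):
        for k in range(0, NK, 4):
            rcds[(i, k)] = (elems[offset], elems[offset + 1])
            offset += 2
        i += 1
    return rcds

# e.g. 4 pairs with NK == 8 yield keys (0, 0), (0, 4), (1, 0), (1, 4)
assert list(extract_loaded_operand(list("abcdefgh"), 8)) == \
    [(0, 0), (0, 4), (1, 0), (1, 4)]
```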
@@ -1675,9 +1667,7 @@ Value DotOpFMAConversionHelper::loadA(
   int strideAK = isARow ? 1 : aShape[0];
   int strideA0 = isARow ? strideAK : strideAM;
   int strideA1 = isARow ? strideAM : strideAK;
-  int lda = isARow ? strideAM : strideAK;
   int aNumPtr = 8;
-  int bNumPtr = 8;
   int NK = aShape[1];

   auto shapePerCTA = getShapePerCTA(dLayout);
@@ -1686,13 +1676,11 @@ Value DotOpFMAConversionHelper::loadA(
   Value _0 = i32_val(0);

   Value mContig = i32_val(sizePerThread[order[1]]);
-  Value nContig = i32_val(sizePerThread[order[0]]);

   // threadId in blocked layout
   auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);

   Value threadIdM = threadIds[0];
-  Value threadIdN = threadIds[1];

   Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
   Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
@@ -1745,7 +1733,6 @@ Value DotOpFMAConversionHelper::loadB(
   int strideBK = isBRow ? bShape[1] : 1;
   int strideB0 = isBRow ? strideBN : strideBK;
   int strideB1 = isBRow ? strideBK : strideBN;
-  int ldb = isBRow ? strideBK : strideBN;
   int bNumPtr = 8;
   int NK = bShape[0];

@@ -1754,7 +1741,6 @@ Value DotOpFMAConversionHelper::loadB(

   Value _0 = i32_val(0);

-  Value mContig = i32_val(sizePerThread[order[1]]);
   Value nContig = i32_val(sizePerThread[order[0]]);

   // threadId in blocked layout
@@ -62,12 +62,11 @@ namespace LLVM {
 static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }

 // A helper function for using printf in LLVM conversion.
-void llPrintf(StringRef msg, ValueRange args,
-              ConversionPatternRewriter &rewriter);
+void vprintf(StringRef msg, ValueRange args,
+             ConversionPatternRewriter &rewriter);

-// Helper function
-#define tid_val() getThreadId(rewriter, loc)
-#define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter)
+void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
+                   std::string elem_repr, ConversionPatternRewriter &builder);

 } // namespace LLVM
 } // namespace mlir
@@ -3537,8 +3536,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   SmallVector<Value> resVals(resSize);

   auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
-    auto ha = has[{m, k}];
-    auto hb = hbs[{n, k}];
+    auto ha = has.at({m, k});
+    auto hb = hbs.at({n, k});
     std::vector<size_t> idx{{
         (m * 2 + 0) + (n * 4 + 0) * numM, // row0
         (m * 2 + 0) + (n * 4 + 1) * numM,
@@ -3554,13 +3553,13 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,

     auto *resOprs = builder.newListOperand(8, "=f");
     auto *AOprs = builder.newListOperand({
-        {ha.first, "f"},
-        {ha.second, "f"},
+        {ha.first, "r"},
+        {ha.second, "r"},
     });

     auto *BOprs = builder.newListOperand({
-        {hb.first, "f"},
-        {hb.second, "f"},
+        {hb.first, "r"},
+        {hb.second, "r"},
     });
     auto *COprs = builder.newListOperand();
     for (int i = 0; i < 8; ++i)
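The constraint change above reflects that each mma.884 operand half is an f16x2 value carried in a 32-bit register, which PTX inline assembly describes with "r" (a 32-bit register); "f" would ask for a .f32 register instead. A small Python sketch of that packing, for illustration only:

```python
import struct

def pack_f16x2(lo: float, hi: float) -> int:
    """Pack two fp16 values into one 32-bit word, as an f16x2 operand is."""
    (lo16,) = struct.unpack('<H', struct.pack('<e', lo))
    (hi16,) = struct.unpack('<H', struct.pack('<e', hi))
    return (hi16 << 16) | lo16

# 1.0 -> 0x3c00 and 2.0 -> 0x4000, packed into one 32-bit register value
assert pack_f16x2(1.0, 2.0) == 0x40003c00
```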
@@ -4806,11 +4805,23 @@ namespace mlir {

 namespace LLVM {

-void llPrintf(StringRef msg, ValueRange args,
-              ConversionPatternRewriter &rewriter) {
+void vprintf(StringRef msg, ValueRange args,
+             ConversionPatternRewriter &rewriter) {
   PrintfOpConversion::llPrintf(msg, args, rewriter);
 }

+void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
+                   std::string elem_repr, ConversionPatternRewriter &builder) {
+  std::string fmt = info + " t-%d ";
+  std::vector<Value> new_arr({thread});
+  for (int i = 0; i < arr.size(); ++i) {
+    fmt += elem_repr + ((i == arr.size() - 1) ? "" : ", ");
+    new_arr.push_back(arr[i]);
+  }
+
+  vprintf(fmt, new_arr, builder);
+}
+
 } // namespace LLVM

 TritonLLVMConversionTarget::TritonLLVMConversionTarget(
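As a usage illustration, the new vprintf_array helper builds one format string for a whole array, prefixed with the thread id. A Python paraphrase of the format construction (the example values are hypothetical):

```python
def vprintf_array_fmt(info: str, n_elems: int, elem_repr: str) -> str:
    # Same construction as the C++ loop above: "<info> t-%d " followed by
    # one elem_repr per array element, comma-separated.
    return info + " t-%d " + ", ".join([elem_repr] * n_elems)

assert vprintf_array_fmt("offA", 3, "%d") == "offA t-%d %d, %d, %d"
```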
@@ -111,6 +111,8 @@
   LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),          \
                             __VA_ARGS__)

+#define tid_val() getThreadId(rewriter, loc)
+
 namespace mlir {
 namespace LLVM {
 using namespace mlir::triton;
@@ -756,6 +756,7 @@ public:
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
     int version = computeCapabilityToMMAVersion(computeCapability);
+
     auto newRetType = RankedTensorType::get(
         retShape, oldRetType.getElementType(),
         triton::gpu::MmaEncodingAttr::get(
@@ -1383,6 +1383,11 @@ void init_triton_translation(py::module &m) {
         llvm::SMDiagnostic error;
         std::unique_ptr<llvm::Module> module =
             llvm::parseIR(buffer->getMemBufferRef(), error, context);
+        if (!module)
+          llvm::report_fatal_error(
+              "failed to parse IR: " + error.getMessage() +
+              "lineno: " + std::to_string(error.getLineNo()));
+
         // translate module to PTX
         auto ptxCode =
             triton::translateLLVMIRToPTX(*module, capability, version);
@@ -292,3 +292,21 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
         torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
     else:
         torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
+
+
+# NOTE this is useful only on Volta GPU.
+@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
+    (shape, num_warps, trans_a, trans_b)
+    for shape in [
+        [16, 16, 16],
+        [16, 16, 32],
+        [32, 16, 16],
+        [32, 32, 32],
+        [128, 16, 16],
+    ]
+    for num_warps in [1]
+    for trans_a in [False]
+    for trans_b in [False]
+])
+def test_gemm_no_scf_for_mmav1(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
+    test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B)