[Triton-MLIR][BACKEND] Tiny patch for MMAv1 and code clean (#964)
This PR:
- Several fixes to the MMAv1 code
- Removes the env var `TRITON_STATIC_LOOP_UNROLLING` from the V100 CI, since the pipeline pass now works
- Some code cleanup
.github/workflows/integration-tests.yml:
@@ -88,9 +88,7 @@ jobs:
       - name: Run python tests on V100
         if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
         run: |
-          # TODO[Superjomn]: Remove the forloop-unroll setting after pipeline pass works
           cd python/tests
-          export TRITON_STATIC_LOOP_UNROLLING=1
           pytest test_gemm.py::test_gemm_for_mmav1

       - name: Run CXX unittests
@@ -43,12 +43,51 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 struct DotOpMmaV1ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned> wpt;
+  static constexpr std::array<int, 3> fpw{{2, 2, 1}};

   using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

   explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
       : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}

+  // Help to share some variables across multiple functions for A.
+  struct AParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+
+    // TODO[Superjomn]: Support the case when isAVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isAVec4{true};
+
+    explicit AParam(bool isARow) {
+      int packSize0 = (isARow || isAVec4) ? 1 : 2;
+      int repM = 2 * packSize0;
+      int repK = 1;
+      int spwM = fpw[0] * 4 * repM;
+      rep.assign({repM, 0, repK});
+      spw.assign({spwM, 0, 1});
+    }
+  };
+
+  // Help to share some variables across multiple functions for B.
+  struct BParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+    // TODO[Superjomn]: Support the case when isBVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isBVec4{true};
+
+    explicit BParam(bool isBRow) {
+      int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+      rep.assign({0, 2 * packSize1, 1});
+      spw.assign({0, fpw[1] * 4 * rep[1], 1});
+    }
+  };
+
   int getRepM(int M) const {
     return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
   }
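With isAVec4/isBVec4 pinned to true, both constructors collapse to constants: rep = {2, 0, 1} / {0, 2, 1} and spw = {16, 0, 1} / {0, 16, 1} for every layout, until the ld.v4 path lands. A minimal standalone sketch (illustrative only, not part of the patch) that mirrors the two constructors:

    #include <array>
    #include <cstdio>

    // Mirrors AParam/BParam with isAVec4/isBVec4 fixed to true, as in the patch.
    static constexpr std::array<int, 3> fpw{{2, 2, 1}};

    int main() {
      for (int isARow : {0, 1}) {
        int packSize0 = (isARow || /*isAVec4=*/true) ? 1 : 2; // always 1 here
        int repM = 2 * packSize0;                             // 2
        int spwM = fpw[0] * 4 * repM;                         // 16
        std::printf("A: isARow=%d rep={%d,0,1} spw={%d,0,1}\n", isARow, repM, spwM);
      }
      for (int isBRow : {0, 1}) {
        int packSize1 = (isBRow && !/*isBVec4=*/true) ? 2 : 1; // always 1 here
        int repN = 2 * packSize1;                              // 2
        int spwN = fpw[1] * 4 * repN;                          // 16
        std::printf("B: isBRow=%d rep={0,%d,1} spw={0,%d,1}\n", isBRow, repN, spwN);
      }
    }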
@@ -65,29 +104,34 @@ struct DotOpMmaV1ConversionHelper {
     return struct_ty(SmallVector<Type>{8, fp32Ty});
   }

-  // number of fp16x2 elements for $a.
-  int numElemsPerThreadA(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
-
-    bool isARow = order[0] != 0;
-    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-    // TODO[Superjomn]: Support the case when isAVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isAVec4 = true;
-
-    int packSize0 = (isARow || isAVec4) ? 1 : 2;
-
-    SmallVector<int> fpw({2, 2, 1});
-    int repM = 2 * packSize0;
-    int repK = 1;
-    int spwM = fpw[0] * 4 * repM;
-    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
-
-    int NK = shape[1];
-    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+  // Get the number of fp16x2 elements for $a.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isARow = orderTransed[0] != 0;
+    AParam param(isARow);
+
+    unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
+    return numM;
+  }
+
+  // Get the number of fp16x2 elements for $b.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isBRow = orderTransed[0] != 0;
+    BParam param(isBRow);
+
+    unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
+    return numN;
+  }
+
+  int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    int numM = getNumM(shapeTransed, orderTransed);
+    int NK = shapeTransed[1];

     // NOTE: We couldn't get the vec from the shared layout.
     // int vecA = sharedLayout.getVec();
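Plugging those constants into getNumM/getNumN: for an assumed 128x128 dot with warpsPerCTA = {4, 2} (illustrative numbers, not from the patch), numM = 2 * 128 / (16 * 4) = 4 and numN = 2 * 128 / (16 * 2) = 8. As a standalone check:

    #include <cstdio>

    // getNumM/getNumN with the vec4 constants (rep = 2, spw = 16 on the non-K
    // axis) and an assumed 128x128 tile with warpsPerCTA wpt = {4, 2}.
    int main() {
      unsigned wpt[2] = {4, 2};
      unsigned numM = 2 * 128 / (16 * wpt[0]); // 4 m-tiles per warp
      unsigned numN = 2 * 128 / (16 * wpt[1]); // 8 n-tiles per warp
      std::printf("numM=%u numN=%u\n", numM, numN);
    }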
@@ -97,39 +141,27 @@ struct DotOpMmaV1ConversionHelper {
     return (numM / 2) * (NK / 4) * elemsPerLd;
   }

-  // number of fp16x2 elements for $b.
-  int numElemsPerThreadB(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
-    bool isBRow = order[0] != 0;
-    bool isBVec4 = isBRow && shape[order[0]] <= 16;
-    // TODO[Superjomn]: Support the case when isBVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isBVec4 = true;
-
-    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-    SmallVector<int> fpw({2, 2, 1});
-    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    unsigned numN = getNumN(shapeTransed, orderTransed);
+    int NK = shapeTransed[0];

     // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
     // TODO[Superjomn]: Consider the case when vecA > 4
     bool vecGt4 = false;
     int elemsPerLd = vecGt4 ? 4 : 2;
-    int NK = shape[0];
-
-    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
     return (numN / 2) * (NK / 4) * elemsPerLd;
   }

   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadA(Value A, bool transA, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;

   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadB(Value B, bool transB, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;

   static ArrayRef<unsigned> getOrder() { return mmaOrder; }
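numElemsPerThreadA/B then size the per-thread register struct as (numM / 2) * (NK / 4) * elemsPerLd, with elemsPerLd = 2 because only ld.v2 is handled. Continuing the assumed 128x128x128 example:

    #include <cstdio>

    // Per-thread fp16x2 element counts implied by numElemsPerThreadA/B, using
    // numM = 4, numN = 8 from above and an assumed NK = 128.
    int main() {
      unsigned numM = 4, numN = 8, NK = 128;
      int elemsPerLd = 2;                                   // vecGt4 == false
      unsigned elemsA = (numM / 2) * (NK / 4) * elemsPerLd; // 2 * 32 * 2 = 128
      unsigned elemsB = (numN / 2) * (NK / 4) * elemsPerLd; // 4 * 32 * 2 = 256
      std::printf("elemsA=%u elemsB=%u\n", elemsA, elemsB);
    }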
@@ -1321,8 +1353,8 @@ struct DotOpFMAConversionHelper {
 };

 Value DotOpMmaV1ConversionHelper::loadA(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {

   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
@@ -1336,24 +1368,11 @@ Value DotOpMmaV1ConversionHelper::loadA(
   Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isARow = order[0] != 0;
-  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-  // TODO[Superjomn]: Support the case when isAVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isAVec4 = true;
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
-
-  SmallVector<int> fpw({2, 2, 1});
-  int repM = 2 * packSize0;
-  int repK = 1;
-  int spwM = fpw[0] * 4 * repM;
-  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
-
-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-
-  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-  bool transA = false;
+  AParam param(isARow);
+
+  auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
+      thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);

   if (transA) {
     std::swap(shape[0], shape[1]);
     std::swap(offsetAM, offsetAK);
@@ -1401,8 +1420,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
   for (int i = 0; i < numPtrA; i++)
     ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);

-  unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
-
   Type f16PtrTy = ptr_ty(f16_ty);

   auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
@@ -1434,6 +1451,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
     }
   };

+  unsigned numM = getNumM(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
       loadA(m, k);
@@ -1451,8 +1469,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
 }

 Value DotOpMmaV1ConversionHelper::loadB(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transB, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
   // smem
   auto strides = smemObj.strides;
@@ -1467,17 +1485,9 @@ Value DotOpMmaV1ConversionHelper::loadB(

   Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
   bool isBRow = order[0] != 0;
-  bool isBVec4 = isBRow && shape[order[0]] <= 16;
-  // TODO[Superjomn]: Support the case when isBVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isBVec4 = true;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
-  int vecB = sharedLayout.getVec();
-
+  BParam param(isBRow);
+
+  int vecB = sharedLayout.getVec();
   Value strideBN = isBRow ? i32_val(1) : strides[1];
   Value strideBK = isBRow ? strides[0] : i32_val(1);
   Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1485,11 +1495,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;

-  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
-  bool transB = false;
-
-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
+      thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
   if (transB) {
     std::swap(order[0], order[1]);
     std::swap(shape[0], shape[1]);
@@ -1556,7 +1563,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
     }
   };

-  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+  unsigned numN = getNumN(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned n = 0; n < numN / 2; ++n) {
       if (!hbs.count({n, k}))
@@ -3408,14 +3408,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
-      res =
-          helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
-    } else if (dotOperandLayout.getOpIdx() == 1) {
-      // operand $b
-      res =
-          helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
+      // TODO[Superjomn]: transA is not available here.
+      bool transA = false;
+      res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
+                         rewriter);
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
+      // TODO[Superjomn]: transB is not available here.
+      bool transB = false;
+      res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc), loc,
+                         rewriter);
     }
   } else {
     assert(false && "Unsupported mma layout found");
@@ -3537,6 +3539,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
+  // TODO[Superjomn]: ld.v4 is not supported.
+  isAVec4 = true;
+  isBVec4 = true;
+
   int packSize0 = (isARow || isAVec4) ? 1 : 2;
   int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
   SmallVector<int> fpw({2, 2, 1});
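The two overrides keep convertMMA884 consistent with what AParam/BParam now hard-code: the predicates are still computed but then ignored, so packSize0/packSize1 always end up 1 (the ld.v2 path). A standalone evaluation of the predicate with an assumed shape:

    #include <cstdio>

    // The original predicate, evaluated for an illustrative column-major A of
    // shape {8, 64}. convertMMA884 now overrides both flags to true regardless,
    // forcing packSize0/packSize1 to 1 (the ld.v2 path), since ld.v4 is not
    // supported yet; this matches what AParam/BParam hard-code.
    int main() {
      long AShape[2] = {8, 64};
      bool isARow = false;
      bool isAVec4 = !isARow && AShape[isARow] <= 16; // 8 <= 16 -> true here
      int packSize0 = (isARow || isAVec4) ? 1 : 2;    // 1
      std::printf("isAVec4=%d packSize0=%d\n", (int)isAVec4, packSize0);
    }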
@@ -3549,7 +3555,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   DotOpMmaV1ConversionHelper helper(mmaLayout);

   unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
-  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
+  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
   unsigned NK = AShape[1];

   auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
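The wpt[0] -> wpt[1] change here is the substantive MMAv1 fix: numN counts tiles along N, so it must be divided by the warp count along N, not along M. With the assumed wpt = {4, 2} from the earlier sketch:

    #include <cstdio>

    // Shows why dividing by wpt[0] was wrong, using the assumed wpt = {4, 2},
    // rep[1] = 2, spw[1] = 16, and DShape[1] = 128.
    int main() {
      unsigned before = 2 * 128 / (16 * 4); // wpt[0]: 4 n-tiles (wrong)
      unsigned after = 2 * 128 / (16 * 2);  // wpt[1]: 8 n-tiles (fixed)
      std::printf("before=%u after=%u\n", before, after);
    }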
@@ -3836,7 +3842,8 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
-    auto shape = type.getShape();
+    SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());

     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -3899,13 +3906,22 @@ public:
       if (mmaLayout.getVersion() == 1) {
         DotOpMmaV1ConversionHelper helper(mmaLayout);

+        // TODO[Superjomn]: Both transA and transB are not available here.
+        bool trans = false;
+        // TODO[Superjomn]: The order of A and B are not available here.
+        SmallVector<unsigned> order({1, 0});
+        if (trans) {
+          std::swap(shape[0], shape[1]);
+          std::swap(order[0], order[1]);
+        }
+
         if (dotOpLayout.getOpIdx() == 0) { // $a
-          int elems = helper.numElemsPerThreadA(type);
+          int elems = helper.numElemsPerThreadA(shape, order);
           Type x2Ty = vec_ty(elemTy, 2);
           return struct_ty(SmallVector<Type>(elems, x2Ty));
         }
         if (dotOpLayout.getOpIdx() == 1) { // $b
-          int elems = helper.numElemsPerThreadB(type);
+          int elems = helper.numElemsPerThreadB(shape, order);
           Type x2Ty = vec_ty(elemTy, 2);
           return struct_ty(SmallVector<Type>(elems, x2Ty));
         }
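This hunk also explains the earlier change of `shape` from `auto shape = type.getShape()` to an owned SmallVector copy: the MMAv1 branch may swap the dimensions in place, which the immutable ArrayRef view would not allow. A standalone sketch of the trans handling under its stated assumptions (trans and operand order are unknown at type-conversion time, so trans = false and order = {1, 0}; values illustrative):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Mirrors the trans/order fallback added to convertTritonTensorType.
    int main() {
      std::vector<long> shape{128, 64};  // stands in for SmallVector<int64_t>
      std::vector<unsigned> order{1, 0}; // row-major assumption
      bool trans = false;                // TODO in the patch: not available here
      if (trans) {
        std::swap(shape[0], shape[1]);
        std::swap(order[0], order[1]);
      }
      std::printf("shape={%ld,%ld} order={%u,%u}\n", shape[0], shape[1],
                  order[0], order[1]);
    }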