[Triton-MLIR][BACKEND] Tiny patch for MMAv1 and code clean (#964)
This PR: - Fixes several issues in the MMAv1 code - Removes the env `TRITON_STATIC_LOOP_UNROLLING` from the V100 CI, since the pipeline pass works now - Some code cleanup
This commit is contained in:
@@ -1730,9 +1730,9 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
|
||||
// concatenate (and potentially reorder) values
|
||||
SmallVector<Value> retVals;
|
||||
for(Value v: lhsVals)
|
||||
for (Value v : lhsVals)
|
||||
retVals.push_back(v);
|
||||
for(Value v: rhsVals)
|
||||
for (Value v : rhsVals)
|
||||
retVals.push_back(v);
|
||||
// pack and replace
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
@@ -3408,14 +3408,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
if (dotOperandLayout.getOpIdx() == 0) {
|
||||
// operand $a
|
||||
res =
|
||||
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||
// operand $b
|
||||
res =
|
||||
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
|
||||
// TODO[Superjomn]: transA is not available here.
|
||||
bool transA = false;
|
||||
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
|
||||
// TODO[Superjomn]: transB is not available here.
|
||||
bool transB = false;
|
||||
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc), loc,
|
||||
rewriter);
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported mma layout found");
|
||||
@@ -3537,6 +3539,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
isAVec4 = true;
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
@@ -3549,7 +3555,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
|
||||
unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
|
||||
unsigned NK = AShape[1];
|
||||
|
||||
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
|
||||
@@ -3836,7 +3842,8 @@ public:
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
auto shape = type.getShape();
|
||||
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
|
||||
|
||||
if (layout &&
|
||||
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
||||
layout.isa<MmaEncodingAttr>())) {
|
||||
@@ -3899,13 +3906,22 @@ public:
|
||||
if (mmaLayout.getVersion() == 1) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
|
||||
// TODO[Superjomn]: Both transA and transB are not available here.
|
||||
bool trans = false;
|
||||
// TODO[Superjomn]: The order of A and B are not available here.
|
||||
SmallVector<unsigned> order({1, 0});
|
||||
if (trans) {
|
||||
std::swap(shape[0], shape[1]);
|
||||
std::swap(order[0], order[1]);
|
||||
}
|
||||
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
int elems = helper.numElemsPerThreadA(type);
|
||||
int elems = helper.numElemsPerThreadA(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
int elems = helper.numElemsPerThreadB(type);
|
||||
int elems = helper.numElemsPerThreadB(shape, order);
|
||||
Type x2Ty = vec_ty(elemTy, 2);
|
||||
return struct_ty(SmallVector<Type>(elems, x2Ty));
|
||||
}
|
||||
|
Reference in New Issue
Block a user