[OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)

Continues the work from https://github.com/openai/triton/pull/990

# Background
The `versionMinor` field in MmaEncodingAttr encodes some state of the DotOp's
operands on Volta. Some patterns modify those operands, which leaves the
encoded state out of date.

This PR helps to correct the states.

# Implementation
It adds the following new patterns:

1. `CollectMmaToUpdateForVolta` collects the MmaEncodingAttr instances with
wrong states into a map and creates new, correct instances for them.
2. `UpdateMMAVersionMinorForVolta` replaces the Ops generating the wrong
MmaEncodingAttr instances with new correct ones. It currently supports the
following Ops:
    a. `convert_layout[X -> mma]`
    b. `arith.constant SplatAttr : !tensor<mma>`
    c. `dot ... : !tensor<mma>`

# Limitation
This PR uses a map to sidestep the complexity of walking the IR, which comes
from the circular dependency between dot_operand[parent] and mma.
The MmaEncodingAttr instance is used as the map key, but there might
be multiple DotOps with different DotOperand(IsMMAv1Row) values that share the
same wrong MmaEncodingAttr instance.
To make each DotOp's (wrong) MmaEncodingAttr unique, we might need to add an ID
field to MmaEncodingAttr.
This commit is contained in:
Yan Chunwei
2022-12-28 12:24:01 +08:00
committed by GitHub
parent fd2da4aff6
commit 2ba74d2729
3 changed files with 281 additions and 26 deletions

View File

@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
}
Type f16x2Ty = vec_ty(f16_ty, 2);
Type elemX2Ty = vec_ty(f16_ty, 2);
Type elemPtrTy = ptr_ty(f16_ty);
if (tensorTy.getElementType().isBF16()) {
elemX2Ty = vec_ty(i16_ty, 2);
elemPtrTy = ptr_ty(i16_ty);
}
// prepare arguments
SmallVector<Value> ptrA(numPtrA);
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
for (int i = 0; i < numPtrA; i++)
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
Type f16PtrTy = ptr_ty(f16_ty);
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
vals[{m, k}] = {val0, val1};
};
auto loadA = [&](int m, int k) {
int offidx = (isARow ? k / 4 : m) % numPtrA;
Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
int stepAM = isARow ? m : m / numPtrA * numPtrA;
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
mul(i32_val(stepAK), strideAK));
Value pa = gep(f16PtrTy, thePtrA, offset);
Value pa = gep(elemPtrTy, thePtrA, offset);
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
Value ha = load(bitcast(pa, aPtrTy));
// record lds that needs to be moved
Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
ld(has, m, k, ha00, ha01);
if (vecA > 4) {
Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
if (isARow)
ld(has, m, k + 4, ha10, ha11);
else
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
elems.push_back(item.second.second);
}
Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
}
Type f16PtrTy = ptr_ty(f16_ty);
Type f16x2Ty = vec_ty(f16_ty, 2);
Type elemPtrTy = ptr_ty(f16_ty);
Type elemX2Ty = vec_ty(f16_ty, 2);
if (tensorTy.getElementType().isBF16()) {
elemPtrTy = ptr_ty(i16_ty);
elemX2Ty = vec_ty(i16_ty, 2);
}
SmallVector<Value> ptrB(numPtrB);
ValueTable hbs;
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
mul(i32_val(stepBK), strideBK));
Value pb = gep(f16PtrTy, thePtrB, offset);
Value pb = gep(elemPtrTy, thePtrB, offset);
Value hb =
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
// record lds that needs to be moved
Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
ld(hbs, n, K, hb00, hb01);
if (vecB > 4) {
Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
if (isBRow)
ld(hbs, n + 1, K, hb10, hb11);
else
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
elems.push_back(item.second.first);
elems.push_back(item.second.second);
}
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
Value res = getStructFromElements(loc, elems, rewriter, resTy);
return res;
}