[OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)
Continue the work https://github.com/openai/triton/pull/990 # Background The `versionMinor` in MmaEncodingAttr holds some states of DotOp's operands in Volta, while such operands will be modified by some patterns, making the states out-of-date. This PR helps to correct the states. # Implementation It adds three new patterns: 1. `CollectMmaToUpdateForVolta` helps to collect and build a map holding the MmaEncodingAttr instances with wrong states and create new correct ones for them, 2. `UpdateMMAVersionMinorForVolta` helps to replace the Ops generating the wrong MmaEncodingAttr instances with new correct ones, currently it supports the following Ops a. `convert_layout[X -> mma]` b. `arith.constant SplatAttr : !tensor<mma>` c. `dot ... : !tensor<mma>` # Limitation This PR chooses the mapping way to bypass the IR walk complexity from the circular dependency between dot_operand[parent] and mma. We use the MmaEncodingAttr instance as the mapping key, but there might be multiple DotOp holding different DotOprand(IsMMAv1Row) that have the same wrong MmaEncodingAttr instance. To make each DotOp's (wrong) MmaEncodingAttr unique, we might need an ID field to MmaEncodingAttr.
This commit is contained in:
@@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper {
|
||||
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
|
||||
}
|
||||
|
||||
Type f16x2Ty = vec_ty(f16_ty, 2);
|
||||
Type elemX2Ty = vec_ty(f16_ty, 2);
|
||||
Type elemPtrTy = ptr_ty(f16_ty);
|
||||
if (tensorTy.getElementType().isBF16()) {
|
||||
elemX2Ty = vec_ty(i16_ty, 2);
|
||||
elemPtrTy = ptr_ty(i16_ty);
|
||||
}
|
||||
|
||||
// prepare arguments
|
||||
SmallVector<Value> ptrA(numPtrA);
|
||||
@@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper {
|
||||
for (int i = 0; i < numPtrA; i++)
|
||||
ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);
|
||||
|
||||
Type f16PtrTy = ptr_ty(f16_ty);
|
||||
|
||||
auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
|
||||
vals[{m, k}] = {val0, val1};
|
||||
};
|
||||
auto loadA = [&](int m, int k) {
|
||||
int offidx = (isARow ? k / 4 : m) % numPtrA;
|
||||
Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]);
|
||||
Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]);
|
||||
|
||||
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
||||
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
||||
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
|
||||
mul(i32_val(stepAK), strideAK));
|
||||
Value pa = gep(f16PtrTy, thePtrA, offset);
|
||||
Value pa = gep(elemPtrTy, thePtrA, offset);
|
||||
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
|
||||
Value ha = load(bitcast(pa, aPtrTy));
|
||||
// record lds that needs to be moved
|
||||
Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
|
||||
Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
|
||||
Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty);
|
||||
Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty);
|
||||
ld(has, m, k, ha00, ha01);
|
||||
|
||||
if (vecA > 4) {
|
||||
Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
|
||||
Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
|
||||
Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty);
|
||||
Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty);
|
||||
if (isARow)
|
||||
ld(has, m, k + 4, ha10, ha11);
|
||||
else
|
||||
@@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
elems.push_back(item.second.second);
|
||||
}
|
||||
|
||||
Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
|
||||
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
|
||||
Value res = getStructFromElements(loc, elems, rewriter, resTy);
|
||||
return res;
|
||||
}
|
||||
@@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper {
|
||||
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
|
||||
}
|
||||
|
||||
Type f16PtrTy = ptr_ty(f16_ty);
|
||||
Type f16x2Ty = vec_ty(f16_ty, 2);
|
||||
Type elemPtrTy = ptr_ty(f16_ty);
|
||||
Type elemX2Ty = vec_ty(f16_ty, 2);
|
||||
if (tensorTy.getElementType().isBF16()) {
|
||||
elemPtrTy = ptr_ty(i16_ty);
|
||||
elemX2Ty = vec_ty(i16_ty, 2);
|
||||
}
|
||||
|
||||
SmallVector<Value> ptrB(numPtrB);
|
||||
ValueTable hbs;
|
||||
@@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper {
|
||||
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
|
||||
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
|
||||
mul(i32_val(stepBK), strideBK));
|
||||
Value pb = gep(f16PtrTy, thePtrB, offset);
|
||||
Value pb = gep(elemPtrTy, thePtrB, offset);
|
||||
|
||||
Value hb =
|
||||
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
||||
// record lds that needs to be moved
|
||||
Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
|
||||
Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
|
||||
Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty);
|
||||
Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty);
|
||||
ld(hbs, n, K, hb00, hb01);
|
||||
if (vecB > 4) {
|
||||
Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
|
||||
Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
|
||||
Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty);
|
||||
Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty);
|
||||
if (isBRow)
|
||||
ld(hbs, n + 1, K, hb10, hb11);
|
||||
else
|
||||
@@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper {
|
||||
elems.push_back(item.second.first);
|
||||
elems.push_back(item.second.second);
|
||||
}
|
||||
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
|
||||
Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
|
||||
Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
|
||||
Value res = getStructFromElements(loc, elems, rewriter, resTy);
|
||||
return res;
|
||||
}
|
||||
|
Reference in New Issue
Block a user