[Triton-MLIR][BACKEND] Make mmav1 works on basic cases (#944)

TODO:

- Add more cases
- Currently, we just set vec to 4 to make the basic cases pass

Issue:

- the vec in shared layout is different compared to master branch
- when vec=1, it encounters CUDA misalignment error, it doesn't work in
master branch as well
- when setting vec to the value identical to master branch, the MMA
works
This commit is contained in:
Yan Chunwei
2022-12-06 10:57:08 +08:00
committed by GitHub
parent 189491727a
commit e419781978
8 changed files with 134 additions and 100 deletions

View File

@@ -87,20 +87,24 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin version 1 ----
if (version == 1) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) :
!is_row && (shape[order[0]] <= 16);
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
return $_get(context, 2 * rep, perPhase, maxPhase, order);
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin version 2 ----
@@ -110,14 +114,14 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
}
// --- handle B operand ---
if (opIdx == 1) {
@@ -125,8 +129,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
}
llvm_unreachable("invalid operand index");
}