[BACKEND] Porting the legacy heuristic rule in assigning shared layout for A/B of MMAv1 (#948)
This commit is contained in:
@@ -92,11 +92,15 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
|||||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||||
|
|
||||||
// ---- begin version 1 ----
|
// ---- begin version 1 ----
|
||||||
// TODO: handle rep (see
|
|
||||||
// https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209)
|
|
||||||
if (version == 1) {
|
if (version == 1) {
|
||||||
|
bool is_row = order[0] != 0;
|
||||||
|
bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) :
|
||||||
|
!is_row && (shape[order[0]] <= 16);
|
||||||
|
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||||
|
((is_row && !is_vec4) ? 2 : 1);
|
||||||
|
int rep = 2 * pack_size;
|
||||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||||
return $_get(context, 1, perPhase, maxPhase, order);
|
return $_get(context, 2 * rep, perPhase, maxPhase, order);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- begin version 2 ----
|
// ---- begin version 2 ----
|
||||||
|
@@ -3035,7 +3035,6 @@ void ConvertLayoutOpConversion::processReplica(
|
|||||||
currVal = zext(llvmElemTy, currVal);
|
currVal = zext(llvmElemTy, currVal);
|
||||||
else if (isPtr)
|
else if (isPtr)
|
||||||
currVal = ptrtoint(llvmElemTy, currVal);
|
currVal = ptrtoint(llvmElemTy, currVal);
|
||||||
|
|
||||||
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
|
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
|
||||||
}
|
}
|
||||||
store(valVec, ptr);
|
store(valVec, ptr);
|
||||||
@@ -3143,7 +3142,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
|||||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
Value src = op.src();
|
Value src = op.src();
|
||||||
Value dst = op.result();
|
Value dst = op.result();
|
||||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||||
|
Reference in New Issue
Block a user