diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index bc848eaf4..7a0cd4324 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -92,12 +92,16 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / unsigned inner = (opIdx == 0) ? 0 : 1; // ---- begin version 1 ---- - // TODO: handle rep (see - // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209) if (version == 1) { + bool is_row = order[0] != 0; + bool is_vec4 = opIdx == 0 ? is_row && (shape[order[0]] <= 16) : + !is_row && (shape[order[0]] <= 16); + int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : + ((is_row && !is_vec4) ? 2 : 1); + int rep = 2 * pack_size; int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; - return $_get(context, 1, perPhase, maxPhase, order); - } + return $_get(context, 2 * rep, perPhase, maxPhase, order); + } // ---- begin version 2 ---- if (version == 2) { diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 8253d1786..19cfce7e0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -3035,7 +3035,6 @@ void ConvertLayoutOpConversion::processReplica( currVal = zext(llvmElemTy, currVal); else if (isPtr) currVal = ptrtoint(llvmElemTy, currVal); - valVec = insert_element(vecTy, valVec, currVal, idx_val(v)); } store(valVec, ptr); @@ -3143,7 +3142,6 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); - Value src = op.src(); Value dst = op.result(); auto srcTy = src.getType().cast();