diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 51203e0b9..d6b7e16b6 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -74,11 +74,9 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / let builders = [ AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, + "ArrayRef":$order, "Type":$eltTy), [{ auto mmaEnc = dotOpEnc.getParent().dyn_cast(); - // Only support row major for now - // TODO(Keren): check why column major code crashes - SmallVector order = {1, 0}; if(!mmaEnc) return $_get(context, 1, 1, 1, order); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 00129325d..0d0b62280 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -464,21 +464,6 @@ struct SharedMemoryObject { } } - // XXX(Keren): a special allocator for 3d tensors. It's a workaround for - // now since we don't have a correct way to encoding 3d tensors in the - // pipeline pass. - SharedMemoryObject(Value base, ArrayRef shape, Location loc, - ConversionPatternRewriter &rewriter) - : base(base) { - auto stride = 1; - for (auto dim : llvm::reverse(shape)) { - strides.emplace_back(i32_val(stride)); - offsets.emplace_back(i32_val(0)); - stride *= dim; - } - strides = llvm::to_vector<4>(llvm::reverse(strides)); - } - SmallVector getElems() const { SmallVector elems; elems.push_back(base); @@ -2251,8 +2236,18 @@ struct AllocTensorOpConversion getTypeConverter()->convertType(resultTy.getElementType()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); smemBase = bitcast(smemBase, elemPtrTy); + auto order = resultTy.getEncoding().cast().getOrder(); + // workaround for 3D tensors + // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors + SmallVector newOrder; + if (resultTy.getShape().size() == 3) + newOrder = {1 + order[0], 1 + order[1], 0}; + else + newOrder = SmallVector(order.begin(), order.end()); + + auto smemObj = - SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter); + SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); @@ -2302,6 +2297,10 @@ struct ExtractSliceOpConversion strideVals.emplace_back(smemObj.strides[i]); } } + + // llvm::outs() << "extract slice\n"; + // llvm::outs() << strideVals[0] << " " << smemObj.strides[1] << "\n"; + // llvm::outs() << strideVals[1] << " " << smemObj.strides[2] << "\n"; auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); auto resTy = op.getType().dyn_cast(); @@ -3262,8 +3261,8 @@ public: cMatShape = matShape[order[0]]; sMatShape = matShape[order[1]]; - cStride = smemStrides[1]; - sStride = smemStrides[0]; + cStride = smemStrides[order[0]]; + sStride = smemStrides[order[1]]; // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; @@ -6202,6 +6201,7 @@ private: dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp, srcType.getShape(), + getOrder(srcBlocked), srcType.getElementType())); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getOperand()); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 50be0b691..cf62ff578 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -201,7 +201,9 @@ LogicalResult LoopPipeliner::initialize() { ty.getShape().end()); bufferShape.insert(bufferShape.begin(), numStages); auto sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), ty.getElementType()); + ty.getContext(), dotOpEnc, ty.getShape(), + triton::gpu::getOrder(ty.getEncoding()), + ty.getElementType()); loadsBufferType[loadOp] = RankedTensorType::get( bufferShape, ty.getElementType(), sharedEnc); } diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index bd8ddd27c..f06d3eda0 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -186,9 +186,9 @@ def get_proper_err(a, b, golden): [128, 256, 128, 4, 128, 256, 32, False, False], [256, 128, 64, 4, 256, 128, 16, False, False], [128, 64, 128, 4, 128, 64, 32, False, False], - # TODO[goostavz]: fix these cases - #[128, 64, 128, 4, 128, 64, 32, True, False], - #[128, 64, 128, 4, 128, 64, 32, False, True], + # trans + [128, 64, 128, 4, 128, 64, 32, True, False], + [128, 64, 128, 4, 128, 64, 32, False, True], ]) def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B): if (TRANS_A): diff --git a/python/triton/compiler.py b/python/triton/compiler.py index e6c595ea7..d74b1f4fd 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -882,6 +882,7 @@ def ttir_to_ttgir(mod, num_warps, num_stages): pm.enable_debug() # Convert blocked layout to mma layout for dot ops so that pipeline # can get shared memory swizzled correctly. + pm.add_coalesce_pass() pm.add_triton_gpu_combine_pass() pm.add_tritongpu_pipeline_pass(num_stages) # Prefetch must be done after pipeline pass because pipeline pass @@ -889,7 +890,6 @@ def ttir_to_ttgir(mod, num_warps, num_stages): pm.add_tritongpu_prefetch_pass() pm.add_canonicalizer_pass() pm.add_cse_pass() - pm.add_coalesce_pass() pm.add_triton_gpu_combine_pass() pm.add_licm_pass() pm.add_triton_gpu_combine_pass() diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp index ea2109552..58c43ade6 100644 --- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp +++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -34,7 +34,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) { // create element type Type eltType = IntegerType::get(&ctx, params.typeWidth); - auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, eltType); + auto layout = SharedEncodingAttr::get(&ctx, encoding, params.shape, {1, 0}, eltType); ASSERT_EQ(layout.getVec(), params.refSwizzle.vec); ASSERT_EQ(layout.getPerPhase(), params.refSwizzle.perPhase);