From ecd1bc33df21889b474267114627d2cbb5658e9c Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Fri, 23 Sep 2022 12:38:14 -0700
Subject: [PATCH] [Triton-MLIR] Keren/code gen for extract slice and alloc
 tensor (#692)

Co-authored-by: gzhu
---
 lib/Analysis/Allocation.cpp                   |  14 +-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 125 ++++++++++++------
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |   7 +-
 test/Analysis/test-allocation.mlir            |   8 +-
 test/Conversion/tritongpu_to_llvm.mlir        |  38 ++++++
 5 files changed, 134 insertions(+), 58 deletions(-)

diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 0ea29afdc..6afa7ea1a 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -43,7 +43,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
       return 0;
     }
   };
-  // blocked -> blocked
   if (srcLayout.isa<BlockedEncodingAttr>() &&
       dstLayout.isa<BlockedEncodingAttr>()) {
     auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
@@ -66,14 +65,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
     }
     paddedRepShape[outOrd[0]] += pad;
   }
-  // blocked -> shared
-  if (srcLayout.isa<BlockedEncodingAttr>() &&
-      dstLayout.isa<SharedEncodingAttr>()) {
-    auto sharedLayout = dstLayout.cast<SharedEncodingAttr>();
-    for (int v : dstTy.getShape())
-      paddedRepShape.push_back(v);
-  }
-
   return paddedRepShape;
 }
 
@@ -140,8 +131,9 @@ private:
     auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
     auto srcEncoding = srcTy.getEncoding();
     auto dstEncoding = dstTy.getEncoding();
-    if (srcEncoding.isa<SharedEncodingAttr>()) {
-      // only block->block and block->shared is supported now
+    if (srcEncoding.isa<SharedEncodingAttr>() ||
+        dstEncoding.isa<SharedEncodingAttr>()) {
+      // Only blocked -> blocked conversion requires scratch allocation
       return;
     }
     // ConvertLayoutOp with both input/output non-shared_layout
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index eec722935..8b9db71a7 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -333,6 +333,13 @@ public:
                                            PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
 
+  explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                           const Allocation *allocation,
+                                           Value smem,
+                                           PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
+        allocation(allocation), smem(smem) {}
+
   Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
     auto cast = rewriter.create<UnrealizedConversionCastOp>(
@@ -585,12 +592,12 @@ public:
     return multiDimIdx;
   }
 
+  template <typename T>
   Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
-                            Value smem, const Allocation *allocation,
-                            Operation *op) const {
+                            T value) const {
     auto ptrTy = LLVM::LLVMPointerType::get(
-        this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3);
-    auto bufferId = allocation->getBufferId(op);
+        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
+    auto bufferId = allocation->getBufferId(value);
     assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
     size_t offset = allocation->getOffset(bufferId);
     auto llvmIndexTy = this->getTypeConverter()->getIndexType();
@@ -598,6 +605,10 @@ public:
     Value base = gep(ptrTy, smem, offVal);
     return base;
   }
+
+protected:
+  const Allocation *allocation;
+  Value smem;
 };
 
 // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
@@ -1332,6 +1343,65 @@ struct AddPtrOpConversion
   }
 };
 
+struct AllocTensorOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getResult());
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    auto llvmElemTy =
+        getTypeConverter()->convertType(resultTy.getElementType());
+    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    Value resultVal =
+        rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
+    rewriter.replaceOp(op, resultVal);
+    return success();
+  }
+};
+
+struct ExtractSliceOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
+    assert(srcLayout && "Unexpected srcLayout in ExtractSliceOpConversion");
+
+    // axis > 0 will result in non-contiguous memory access if the result
+    // tensor is an alias of the source tensor.
+    auto axis =
+        op->getAttrOfType<IntegerAttr>("axis").getInt();
+    assert(axis == 0 && "Only axis=0 is supported for now");
+
+    // Example:
+    // %dst = extract_slice %src, %index {axis = 0}
+    // src.shape = [11, 2, 3, 4, 1]
+    // offset = %index * 2 * 3 * 4 * 1
+    auto dstTy = op.getType().dyn_cast<RankedTensorType>();
+    auto base = product<int64_t>(dstTy.getShape());
+    auto baseVal = createIndexAttrConstant(
+        rewriter, loc, getTypeConverter()->getIndexType(), base);
+    Value offset = rewriter.create<LLVM::MulOp>(loc, adaptor.index(), baseVal);
+
+    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
+    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    Value resultVal =
+        rewriter.create<LLVM::GEPOp>(loc, elemPtrTy, adaptor.src(), offset);
+    rewriter.replaceOp(op, resultVal);
+    return success();
+  }
+};
+
 template <typename SourceOp, typename DestOp>
 class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
@@ -1379,13 +1449,6 @@ public:
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  ConvertLayoutOpConversion(LLVMTypeConverter &converter,
-                            const Allocation *allocation, Value smem,
-                            PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(converter,
-                                                                      benefit),
-        allocation(allocation), smem(smem) {}
-
   LogicalResult
   matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if ((!srcLayout.isa<BlockedEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>())) {
       // TODO: not implemented
-      llvm::errs()
-          << "convert_layout except for blocked -> blocked is not implemented";
       return failure();
     }
     auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    Value smemBase =
-        getSharedMemoryBase(loc, rewriter, smem, allocation, op.getOperation());
+    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
     auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
     smemBase = bit_cast(elemPtrTy, smemBase);
@@ -1587,9 +1647,6 @@ private:
       }
     }
   }
-
-  const Allocation *allocation;
-  Value smem;
 };
 
 /// ====================== dot codegen begin ==========================
@@ -1926,11 +1983,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     NOT_APPLICABLE,
   };
 
-  explicit DotOpConversion(LLVMTypeConverter &typeConverter,
-                           const Allocation *allocation, Value smem,
-                           PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit),
-        allocation(allocation), smem(smem) {}
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -1995,15 +2049,6 @@ private:
     assert(false && "Not implemented yet.");
     return failure();
   }
-
-  Value getSmemAddr(Value value, Location loc,
-                    ConversionPatternRewriter &rewriter) const {
-    return getSharedMemoryBase(loc, rewriter, smem, allocation,
-                               value.getDefiningOp());
-  }
-
-  const Allocation *allocation;
-  Value smem;
 };
 
 struct DotOpConversionHelper {
@@ -2340,7 +2385,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
     SmallVector<Value> ptrs(numPtrs);
 
     Type smemPtrTy = helper.getShemPtrTy();
-    auto smemBase = getSmemAddr(tensor, loc, rewriter);
+    auto smemBase = getSharedMemoryBase(loc, rewriter, tensor);
     for (int i = 0; i < numPtrs; i++) {
       ptrs[i] = bit_cast(
           smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]})));
@@ -2479,10 +2524,12 @@ public:
       SmallVector<Type> types(numElementsPerThread,
                               convertType(type.getElementType()));
       return LLVM::LLVMStructType::getLiteral(&getContext(), types);
-    } else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
-      return type;
-    } else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
+    } else if (auto mma_layout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
+      // TODO: Not implemented
       return type;
+    } else if (auto shared_layout =
+                   layout.dyn_cast_or_null<SharedEncodingAttr>()) {
+      return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
     }
     return llvm::None;
   }
@@ -2493,6 +2540,9 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   AxisInfoAnalysis &axisInfoAnalysis,
                                   const Allocation *allocation, Value smem,
                                   PatternBenefit benefit = 1) {
+  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
+  patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
+                                        benefit);
   patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
   patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
                                                                benefit);
@@ -2503,9 +2553,10 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
                                                                 benefit);
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
-  patterns.add<AddPtrOpConversion>(typeConverter, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
+  patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
+                                         benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d6fd864b7..a590e9137 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -431,9 +431,10 @@ mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
   auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
   if (axis < 0 || axis > srcShape.size())
     return failure();
-  // Since we only extract a slice from a certain index on the axis,
-  // the dims before the axis can be dropped.
-  auto dstShape = srcShape.drop_front(axis + 1);
+  SmallVector<int64_t> dstShape;
+  for (int i = 0; i < srcShape.size(); i++)
+    if (i != axis)
+      dstShape.push_back(srcShape[i]);
   auto returnType =
       RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
   inferredReturnTypes.assign({returnType});
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index 2fbff865d..19d403dbb 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -22,11 +22,9 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
     %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-    // CHECK: scratch offset = 8192, size = 0
-    // CHECK-NEXT: offset = 0, size = 8192
+    // CHECK: offset = 0, size = 8192
    %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
-    // CHECK-NEXT: scratch offset = 16384, size = 0
     // CHECK-NEXT: offset = 8192, size = 8192
    %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
 
@@ -52,20 +50,16 @@ func @reusable(%A : !tt.ptr<f16>) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
   %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK: scratch offset = 8192, size = 0
   // CHECK-NEXT: offset = 0, size = 8192
   %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-  // CHECK-NEXT: scratch offset = 16384, size = 0
   // CHECK-NEXT: offset = 8192, size = 8192
   %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  // CHECK-NEXT: scratch offset = 24576, size = 0
   // CHECK-NEXT: offset = 16384, size = 8192
   %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
   %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
   %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
-  // CHECK-NEXT: scratch offset = 8192, size = 0
   // CHECK-NEXT: offset = 0, size = 8192
   %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A>
   %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 915cc6c0c..7f5050ec5 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -293,6 +293,44 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 // -----
 
+#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK-LABEL: basic_alloc_tensor
+  func @basic_alloc_tensor() {
+    // CHECK: llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: llvm.mlir.constant
+    // CHECK-NEXT: llvm.getelementptr
+    // CHECK-NEXT: llvm.bitcast
+    %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
+    return
+  }
+}
+
+// -----
+
+#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK: llvm.mlir.global internal @global_smem
+  // CHECK-LABEL: basic_extract_slice
+  func @basic_extract_slice() {
+    // CHECK: %[[BASE0:.*]] = llvm.mlir.addressof @global_smem
+    // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: llvm.getelementptr %[[BASE0]][%[[OFFSET1]]]
+    // CHECK-NEXT: %[[BASE1:.*]] = llvm.bitcast
+    // CHECK-NEXT: %[[OFFSET2:.*]] = llvm.mlir.constant
+    // CHECK-NEXT: %[[OFFSET3:.*]] = llvm.mul %[[OFFSET0]], %[[OFFSET2]]
+    // CHECK-NEXT: llvm.getelementptr %[[BASE1]][%[[OFFSET3]]]
+    %index = arith.constant 1 : i32
+    %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
+    %1 = triton_gpu.extract_slice %0, %index {axis = 0 : i32} : tensor<128x16x32xf32, #shared0> -> tensor<16x32xf32, #shared0>
+    return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK: basic_splat
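
The offset arithmetic that ExtractSliceOpConversion emits above can be checked on the CPU: a slice at `index` along axis 0 starts `index * product(dstShape)` elements into the buffer. A minimal standalone sketch, plain C++ rather than MLIR; `sliceOffset` is a hypothetical helper, not part of the patch:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Mirrors ExtractSliceOpConversion's offset math: a slice at `index` along
// axis 0 begins index * product(srcShape[1:]) elements into the buffer.
int64_t sliceOffset(int64_t index, const std::vector<int64_t> &srcShape) {
  return index * std::accumulate(srcShape.begin() + 1, srcShape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  // Matches the comment in the patch: src.shape = [11, 2, 3, 4, 1],
  // so slice i begins at i * 2 * 3 * 4 * 1 = 24 * i elements.
  return sliceOffset(1, {11, 2, 3, 4, 1}) == 24 ? 0 : 1;
}
```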
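The refactored getSharedMemoryBase follows the same shape: look up the buffer's byte offset in the Allocation analysis, then advance the shared-memory arena's base pointer. A sketch of that bookkeeping under the assumption that the arena is one flat byte buffer; `AllocationSketch` and `sharedMemoryBase` are illustrative stand-ins, not the patch's MLIR types:

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>

// Illustrative stand-in for the Allocation analysis: maps a buffer id to a
// byte offset inside the single shared-memory arena.
struct AllocationSketch {
  static constexpr int InvalidBufferId = -1;
  std::unordered_map<int, size_t> offsets;
  size_t getOffset(int bufferId) const { return offsets.at(bufferId); }
};

// Mirrors getSharedMemoryBase: assert the buffer exists, then advance the
// arena base by its offset (the real lowering emits a GEP on an i8 pointer
// in address space 3).
char *sharedMemoryBase(char *smem, const AllocationSketch &alloc,
                       int bufferId) {
  assert(bufferId != AllocationSketch::InvalidBufferId && "BufferId not found");
  return smem + alloc.getOffset(bufferId);
}

int main() {
  char arena[16384] = {};
  AllocationSketch alloc{{{0, 0}, {1, 8192}}};
  // Buffer 1 lives 8192 bytes into the arena, as in the test expectations.
  return sharedMemoryBase(arena, alloc, 1) == arena + 8192 ? 0 : 1;
}
```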