diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index d31b3df3c..90abf61c6 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2136,13 +2136,43 @@ struct GetProgramIdOpConversion
   matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    assert(op.axis() < 3);
+
     Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
-        loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
+        loc, rewriter.getIndexType(), dims[op.axis()]);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         op, TypeRange{llvmIndexTy}, ValueRange{blockId});
     return success();
   }
+
+  static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
+                                                  mlir::gpu::Dimension::y,
+                                                  mlir::gpu::Dimension::z};
+};
+
+// Lowers tt.get_num_programs along a given axis.  The "number of programs"
+// is the number of CTAs in the launch grid, so this must read gridDim
+// (PTX sreg nctaid) -- NOT blockDim (ntid), which is threads per CTA.
+struct GetNumProgramsOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    assert(op.axis() < 3);
+
+    Value gridDim = rewriter.create<::mlir::gpu::GridDimOp>(
+        loc, rewriter.getIndexType(), dims[op.axis()]);
+    auto llvmIndexTy = getTypeConverter()->getIndexType();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{llvmIndexTy}, ValueRange{gridDim});
+    return success();
+  }
+
+  static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
+                                                  mlir::gpu::Dimension::y,
+                                                  mlir::gpu::Dimension::z};
 };
 
 struct AddPtrOpConversion
@@ -6072,6 +6102,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
+  patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
   patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
                                              axisInfoAnalysis, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index e4095671b..cba6082ec 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -849,3 +849,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   return
 }
 }
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
+func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+  %blockidx = tt.get_program_id {axis=0:i32} : i32
+  %blockidy = tt.get_program_id {axis=1:i32} : i32
+  %blockidz = tt.get_program_id {axis=2:i32} : i32
+  // CHECK: nvvm.read.ptx.sreg.ctaid.x
+  // CHECK: nvvm.read.ptx.sreg.ctaid.y
+  // CHECK: nvvm.read.ptx.sreg.ctaid.z
+  %v0 = arith.addi %blockidx, %blockidy : i32
+  %v1 = arith.addi %v0, %blockidz : i32
+  %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
+  tt.store %a, %0 : tensor<32xi32, #blocked0>
+
+  return
+}
+
+}
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
+func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+  // get_num_programs reads the grid size (nctaid), not the CTA size (ntid).
+  // CHECK: nvvm.read.ptx.sreg.nctaid.x
+  // CHECK: nvvm.read.ptx.sreg.nctaid.y
+  // CHECK: nvvm.read.ptx.sreg.nctaid.z
+  %griddimx = tt.get_num_programs {axis=0:i32} : i32
+  %griddimy = tt.get_num_programs {axis=1:i32} : i32
+  %griddimz = tt.get_num_programs {axis=2:i32} : i32
+  %v0 = arith.addi %griddimx, %griddimy : i32
+  %v1 = arith.addi %v0, %griddimz : i32
+  %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
+  tt.store %a, %0 : tensor<32xi32, #blocked0>
+
+  return
+}
+
+}