[Triton-MLIR][BACKEND] Refine/add codegen for get_program_id and get_num_programs Op (#877)
This commit is contained in:
@@ -2136,13 +2136,43 @@ struct GetProgramIdOpConversion
|
||||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
assert(op.axis() < 3);
|
||||
|
||||
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
|
||||
loc, rewriter.getIndexType(), dims[op.axis()]);
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
|
||||
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
|
||||
return success();
|
||||
}
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
struct GetNumProgramsOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
assert(op.axis() < 3);
|
||||
|
||||
Value blockId = rewriter.create<::mlir::gpu::BlockDimOp>(
|
||||
loc, rewriter.getIndexType(), dims[op.axis()]);
|
||||
auto llvmIndexTy = getTypeConverter()->getIndexType();
|
||||
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
|
||||
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
|
||||
return success();
|
||||
}
|
||||
|
||||
static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
|
||||
mlir::gpu::Dimension::y,
|
||||
mlir::gpu::Dimension::z};
|
||||
};
|
||||
|
||||
struct AddPtrOpConversion
|
||||
@@ -6072,6 +6102,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
|
||||
axisInfoAnalysis, benefit);
|
||||
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
|
@@ -849,3 +849,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// Lowers tt.get_program_id for axes 0, 1 and 2 and expects one ctaid special
// register read per axis, in x, y, z order.
func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
  %pid_x = tt.get_program_id {axis=0:i32} : i32
  %pid_y = tt.get_program_id {axis=1:i32} : i32
  %pid_z = tt.get_program_id {axis=2:i32} : i32
  // CHECK: nvvm.read.ptx.sreg.ctaid.x
  // CHECK: nvvm.read.ptx.sreg.ctaid.y
  // CHECK: nvvm.read.ptx.sreg.ctaid.z

  // All three ids are summed and stored through %a, so each one has a use.
  %sum_xy = arith.addi %pid_x, %pid_y : i32
  %sum_xyz = arith.addi %sum_xy, %pid_z : i32
  %splat = tt.splat %sum_xyz : (i32) -> tensor<32xi32, #blocked0>
  tt.store %a, %splat : tensor<32xi32, #blocked0>

  return
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// Lowers tt.get_num_programs for axes 0, 1 and 2.
//
// The number of programs is the launch grid size, which PTX exposes through
// the nctaid.{x,y,z} special registers (CTAs per grid). The ntid.* registers
// hold the number of threads per CTA and are not the program count.
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
  // CHECK: nvvm.read.ptx.sreg.nctaid.x
  // CHECK: nvvm.read.ptx.sreg.nctaid.y
  // CHECK: nvvm.read.ptx.sreg.nctaid.z
  %nprog_x = tt.get_num_programs {axis=0:i32} : i32
  %nprog_y = tt.get_num_programs {axis=1:i32} : i32
  %nprog_z = tt.get_num_programs {axis=2:i32} : i32

  // All three counts are summed and stored through %a, so each one has a use.
  %v0 = arith.addi %nprog_x, %nprog_y : i32
  %v1 = arith.addi %v0, %nprog_z : i32
  %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
  tt.store %a, %0 : tensor<32xi32, #blocked0>

  return
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user