[Triton-MLIR][BACKEND] Refine/add codegen for get_program_id and get_num_programs Op (#877)

Yan Chunwei
2022-11-15 15:45:24 +08:00
committed by GitHub
parent 4c4159c6fa
commit a22ff39017
2 changed files with 74 additions and 1 deletion
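In short: `tt.get_program_id` previously lowered to the x block id regardless of its `axis` attribute; this change selects the GPU dimension from a small lookup table indexed by `axis`, and adds a first lowering for `tt.get_num_programs`. At the Triton dialect level the two ops look like this (a minimal sketch distilled from the tests below; SSA names are illustrative):

%pid = tt.get_program_id {axis=0:i32} : i32
%np = tt.get_num_programs {axis=0:i32} : i32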


@@ -2136,13 +2136,43 @@ struct GetProgramIdOpConversion
   matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    assert(op.axis() < 3);
     Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
-        loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
+        loc, rewriter.getIndexType(), dims[op.axis()]);
     auto llvmIndexTy = getTypeConverter()->getIndexType();
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         op, TypeRange{llvmIndexTy}, ValueRange{blockId});
     return success();
   }
+
+  static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
+                                                  mlir::gpu::Dimension::y,
+                                                  mlir::gpu::Dimension::z};
 };
+
+struct GetNumProgramsOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::GetNumProgramsOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    assert(op.axis() < 3);
+    Value blockId = rewriter.create<::mlir::gpu::BlockDimOp>(
+        loc, rewriter.getIndexType(), dims[op.axis()]);
+    auto llvmIndexTy = getTypeConverter()->getIndexType();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{llvmIndexTy}, ValueRange{blockId});
+    return success();
+  }
+
+  static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
+                                                  mlir::gpu::Dimension::y,
+                                                  mlir::gpu::Dimension::z};
+};
 
 struct AddPtrOpConversion
@@ -6072,6 +6102,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                          benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
+  patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
   patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
                                              axisInfoAnalysis, benefit);
   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
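Both conversion patterns share the same shape: look up the `gpu` dialect dimension for the op's `axis` attribute in a constexpr table, emit `gpu.block_id` (program id) or `gpu.block_dim` (program count) at `index` type, and hand the result to the LLVM type system through an `unrealized_conversion_cast` that the conversion framework folds away once both sides are converted. A hedged sketch of the intermediate IR produced for `axis = 1`, assuming the type converter chooses a 64-bit index type:

// Sketch only: what GetProgramIdOpConversion emits for {axis=1:i32}.
// The i64 width is an assumption about the target's index bit-width.
%bid = gpu.block_id y
%pid = builtin.unrealized_conversion_cast %bid : index to i64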


@@ -849,3 +849,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+    %blockidx = tt.get_program_id {axis=0:i32} : i32
+    %blockidy = tt.get_program_id {axis=1:i32} : i32
+    %blockidz = tt.get_program_id {axis=2:i32} : i32
+    // CHECK: nvvm.read.ptx.sreg.ctaid.x
+    // CHECK: nvvm.read.ptx.sreg.ctaid.y
+    // CHECK: nvvm.read.ptx.sreg.ctaid.z
+    %v0 = arith.addi %blockidx, %blockidy : i32
+    %v1 = arith.addi %v0, %blockidz : i32
+    %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
+    tt.store %a, %0 : tensor<32xi32, #blocked0>
+    return
+  }
+}
+
+// -----
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+    // CHECK: nvvm.read.ptx.sreg.ntid.x
+    // CHECK: nvvm.read.ptx.sreg.ntid.y
+    // CHECK: nvvm.read.ptx.sreg.ntid.z
+    %blockdimx = tt.get_num_programs {axis=0:i32} : i32
+    %blockdimy = tt.get_num_programs {axis=1:i32} : i32
+    %blockdimz = tt.get_num_programs {axis=2:i32} : i32
+    %v0 = arith.addi %blockdimx, %blockdimy : i32
+    %v1 = arith.addi %v0, %blockdimz : i32
+    %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
+    tt.store %a, %0 : tensor<32xi32, #blocked0>
+    return
+  }
+}
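Once the rest of the pipeline lowers the `gpu` dialect to NVVM, each axis collapses to a single special-register read, which is exactly what the CHECK lines pin down. Roughly, for the program-id test (a sketch; result names are illustrative):

%pid_x = nvvm.read.ptx.sreg.ctaid.x : i32
%pid_y = nvvm.read.ptx.sreg.ctaid.y : i32
%pid_z = nvvm.read.ptx.sreg.ctaid.z : i32

The `tt.get_num_programs` test correspondingly checks for the `nvvm.read.ptx.sreg.ntid.*` reads emitted through `gpu.block_dim`.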