From d9dd97492f228020573b39a9cec14ee3b8776957 Mon Sep 17 00:00:00 2001 From: daadaada Date: Fri, 25 Feb 2022 08:07:10 +0800 Subject: [PATCH] Use unique_ptr in ir::context_impl (#462) Co-authored-by: Philippe Tillet --- include/triton/ir/context.h | 2 -- include/triton/ir/context_impl.h | 18 ++++++++---------- lib/ir/constant.cc | 20 ++++++++++---------- lib/ir/type.cc | 12 ++++++------ 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/include/triton/ir/context.h b/include/triton/ir/context.h index 55edf31cd..d824c98b6 100644 --- a/include/triton/ir/context.h +++ b/include/triton/ir/context.h @@ -9,7 +9,6 @@ namespace triton{ namespace ir{ -class builder; class type; class context_impl; @@ -21,7 +20,6 @@ public: context& operator=(const context&) = delete; public: - ir::builder* builder = nullptr; std::shared_ptr p_impl; }; diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index e43b5ad57..081ea249d 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -3,17 +3,15 @@ #ifndef _TRITON_IR_CONTEXT_IMPL_H_ #define _TRITON_IR_CONTEXT_IMPL_H_ -#include #include "triton/ir/type.h" +#include "triton/ir/constant.h" +#include +#include namespace triton{ namespace ir{ class context; -class constant; -class constant_int; -class constant_fp; -class undef_value; /* Context impl */ class context_impl { @@ -30,16 +28,16 @@ public: integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types - std::map, pointer_type*> ptr_tys; + std::map, std::unique_ptr> ptr_tys; // Block types - std::map, block_type*> block_tys; + std::map, std::unique_ptr> block_tys; // Int constants - std::map, constant_int*> int_constants_; + std::map, std::unique_ptr> int_constants_; // Float constants - std::map, constant_fp*> fp_constants_; + std::map, std::unique_ptr> fp_constants_; // undef values - std::map uv_constants_; + std::map> uv_constants_; }; diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index b2a50c3be..ab1f6f497 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -47,10 +47,10 @@ constant_int *constant_int::get(type *ty, uint64_t value) { if (!ty->is_integer_ty()) throw std::runtime_error("Cannot create constant_int with non integer ty"); context_impl *impl = ty->get_context().p_impl.get(); - constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; - if(cst == nullptr) - cst = new constant_int(ty, value); - return cst; + std::unique_ptr &cst = impl->int_constants_[std::make_pair(ty, value)]; + if(!cst) + cst.reset(new constant_int(ty, value)); + return cst.get(); } @@ -73,10 +73,10 @@ constant *constant_fp::get_zero_value_for_negation(type *ty) { constant *constant_fp::get(type *ty, double v){ context_impl *impl = ty->get_context().p_impl.get(); - constant_fp *&result = impl->fp_constants_[std::make_pair(ty, v)]; + std::unique_ptr &result = impl->fp_constants_[std::make_pair(ty, v)]; if(!result) - result = new constant_fp(ty, v); - return result; + result.reset(new constant_fp(ty, v)); + return result.get(); } @@ -86,10 +86,10 @@ undef_value::undef_value(type *ty) undef_value *undef_value::get(type *ty) { context_impl *impl = ty->get_context().p_impl.get(); - undef_value *&result = impl->uv_constants_[ty]; + std::unique_ptr &result = impl->uv_constants_[ty]; if(!result) - result = new undef_value(ty); - return result; + result.reset(new undef_value(ty)); + return result.get(); } /* global value */ diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 74066a65a..7e4e4e5d7 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -167,10 +167,10 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); - pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; + std::unique_ptr &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; if(!entry) - entry = new pointer_type(elt_ty, address_space); - return entry; + entry.reset(new pointer_type(elt_ty, address_space)); + return entry.get(); } //===----------------------------------------------------------------------===// @@ -217,10 +217,10 @@ block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) { assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); - block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; + std::unique_ptr &entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; if(!entry) - entry = new block_type(elt_ty, shapes); - return entry; + entry.reset(new block_type(elt_ty, shapes)); + return entry.get(); } block_type* block_type::get_same_shapes(type *ty, type *ref){