Use unique_ptr in ir::context_impl (#462)

Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
This commit is contained in:
daadaada
2022-02-25 08:07:10 +08:00
committed by GitHub
parent 98ed7db8c1
commit d9dd97492f
4 changed files with 24 additions and 28 deletions

View File

@@ -9,7 +9,6 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
class builder;
class type; class type;
class context_impl; class context_impl;
@@ -21,7 +20,6 @@ public:
context& operator=(const context&) = delete; context& operator=(const context&) = delete;
public: public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl; std::shared_ptr<context_impl> p_impl;
}; };

View File

@@ -3,17 +3,15 @@
#ifndef _TRITON_IR_CONTEXT_IMPL_H_ #ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_ #define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/constant.h"
#include <map>
#include <memory>
namespace triton{ namespace triton{
namespace ir{ namespace ir{
class context; class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */ /* Context impl */
class context_impl { class context_impl {
@@ -30,16 +28,16 @@ public:
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
// Pointer types // Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys; std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types // Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys; std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
// Int constants // Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_; std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
// Float constants // Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_; std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
// undef values // undef values
std::map<type*, undef_value*> uv_constants_; std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
}; };

View File

@@ -47,10 +47,10 @@ constant_int *constant_int::get(type *ty, uint64_t value) {
if (!ty->is_integer_ty()) if (!ty->is_integer_ty())
throw std::runtime_error("Cannot create constant_int with non integer ty"); throw std::runtime_error("Cannot create constant_int with non integer ty");
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; std::unique_ptr<constant_int> &cst = impl->int_constants_[std::make_pair(ty, value)];
if(cst == nullptr) if(!cst)
cst = new constant_int(ty, value); cst.reset(new constant_int(ty, value));
return cst; return cst.get();
} }
@@ -73,10 +73,10 @@ constant *constant_fp::get_zero_value_for_negation(type *ty) {
constant *constant_fp::get(type *ty, double v){ constant *constant_fp::get(type *ty, double v){
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
constant_fp *&result = impl->fp_constants_[std::make_pair(ty, v)]; std::unique_ptr<constant_fp> &result = impl->fp_constants_[std::make_pair(ty, v)];
if(!result) if(!result)
result = new constant_fp(ty, v); result.reset(new constant_fp(ty, v));
return result; return result.get();
} }
@@ -86,10 +86,10 @@ undef_value::undef_value(type *ty)
undef_value *undef_value::get(type *ty) { undef_value *undef_value::get(type *ty) {
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
undef_value *&result = impl->uv_constants_[ty]; std::unique_ptr<undef_value> &result = impl->uv_constants_[ty];
if(!result) if(!result)
result = new undef_value(ty); result.reset(new undef_value(ty));
return result; return result.get();
} }
/* global value */ /* global value */

View File

@@ -167,10 +167,10 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
// look-up // look-up
context_impl *impl = elt_ty->get_context().p_impl.get(); context_impl *impl = elt_ty->get_context().p_impl.get();
pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
if(!entry) if(!entry)
entry = new pointer_type(elt_ty, address_space); entry.reset(new pointer_type(elt_ty, address_space));
return entry; return entry.get();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -217,10 +217,10 @@ block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
// look-up // look-up
context_impl *impl = elt_ty->get_context().p_impl.get(); context_impl *impl = elt_ty->get_context().p_impl.get();
block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
if(!entry) if(!entry)
entry = new block_type(elt_ty, shapes); entry.reset(new block_type(elt_ty, shapes));
return entry; return entry.get();
} }
block_type* block_type::get_same_shapes(type *ty, type *ref){ block_type* block_type::get_same_shapes(type *ty, type *ref){