diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 6e177a47f..23be73256 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -21,6 +21,7 @@ protected: public: static constant* get_all_ones_value(type *ty); static constant* get_null_value(type *ty); + virtual std::string repr() const = 0; }; /* Undef value */ @@ -30,6 +31,7 @@ private: public: static undef_value* get(type* ty); + std::string repr() const { return "undef"; } }; @@ -40,8 +42,8 @@ protected: public: virtual uint64_t get_value() const { return value_; } - virtual std::string repr() const { return std::to_string(get_value()); } static constant_int *get(type *ty, uint64_t value); + std::string repr() const { return std::to_string(value_); } protected: uint64_t value_; @@ -66,28 +68,6 @@ private: bool has_value_; }; -class constant_expression: public constant_int { - typedef binary_op_t op_t; - -private: - constant_expression(op_t op, constant_int* lhs, constant_int* rhs); - -public: - uint64_t get_value() const; - // Wraps - void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } - void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } - // Factory - static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs); - -private: - op_t op_; - constant_int* lhs_; - constant_int* rhs_; - bool has_no_unsigned_wrap_; - bool has_no_signed_wrap_; -}; - /* constant range */ class constant_range: public constant{ constant_range(type *ty, constant_int* first, constant_int* last); @@ -96,6 +76,7 @@ public: static constant *get(constant_int *first, constant_int *last); const constant_int* get_first() const; const constant_int* get_last() const; + std::string repr() const { return first_->repr() + " ... " + last_->repr(); } private: constant_int* first_; @@ -112,6 +93,7 @@ public: static constant* get_zero_value_for_negation(type *ty); static constant* get(context &ctx, double v); static constant* get(type *ty, double v); + std::string repr() const { return std::to_string(value_); } private: double value_; @@ -128,6 +110,7 @@ public: global_value(type *ty, unsigned num_ops, linkage_types_t linkage, const std::string &name, unsigned addr_space); + std::string repr() const { return get_name(); } private: linkage_types_t linkage_; @@ -139,6 +122,8 @@ public: global_object(type *ty, unsigned num_ops, linkage_types_t linkage, const std::string &name, unsigned addr_space = 0); + std::string repr() const { return get_name(); } + }; /* global variable */ @@ -146,6 +131,8 @@ class alloc_const: public global_object { public: alloc_const(type *ty, constant_int *size, const std::string &name = ""); + std::string repr() const { return get_name(); } + }; } diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index df26796c6..5995de0d4 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -11,7 +11,6 @@ namespace ir{ class context; class constant; -class constant_expression; class constant_int; class constant_fp; class undef_value; @@ -39,8 +38,6 @@ public: std::map uv_constants_; // Metaparameters std::vector mp_constants_; - // Expr constants - std::map, constant_expression*> expr_constants_; }; } diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 5c6af5362..f0a345c81 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -309,7 +309,7 @@ public: // ternary class ternary_inst: public instruction { private: - std::string repr_impl() const { return "ternary"; } + std::string repr_impl() const { return "cond"; } ternary_inst(value *cond, value *true_value, value *false_value, const std::string &name, instruction *next); @@ -438,7 +438,6 @@ public: class retile_inst: public unary_inst { protected: retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next); - static std::string shape_suffix(ir::type* ty); }; // reshape @@ -446,7 +445,7 @@ protected: class reshape_inst: public retile_inst { private: using retile_inst::retile_inst; - std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); } + std::string repr_impl() const { return "reshape"; } public: static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, @@ -458,7 +457,7 @@ public: class splat_inst: public retile_inst { private: using retile_inst::retile_inst; - std::string repr_impl() const { return "splat" + shape_suffix(get_type()); } + std::string repr_impl() const { return "splat"; } public: static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, @@ -470,7 +469,7 @@ public: class broadcast_inst: public retile_inst { private: using retile_inst::retile_inst; - std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); } + std::string repr_impl() const { return "broadcast"; } public: static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, @@ -688,6 +687,7 @@ private: public: static nv_static_program_idx *get(constant_range* range); constant_range* get_range() const; + std::string repr() const { return get_name(); } private: constant_range *range_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 60d2d9691..aee2ecc42 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -3,7 +3,9 @@ #ifndef _TRITON_IR_TYPE_H_ #define _TRITON_IR_TYPE_H_ +#include #include +#include namespace triton{ namespace ir{ @@ -102,6 +104,42 @@ public: static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); + // repr + std::string tile_repr() const { + std::string res = get_tile_element_ty()->repr(); + auto shapes = get_tile_shapes(); + res += "<"; + for(size_t i = 0; i < shapes.size(); i++){ + if(i > 0) + res += ", "; + res += std::to_string(shapes[i]); + } + res+= ">"; + return res; + } + + std::string repr() const { + switch(id_) { + case VoidTyID: return "void"; + case HalfTyID: return "f16"; + case FloatTyID: return "f32"; + case DoubleTyID: return "f64"; + case X86_FP80TyID: return "f80"; + case FP128TyID: return "f128"; + case PPC_FP128TyID: return "ppcf128"; + case LabelTyID: return "label"; + case MetadataTyID: return "md"; + case TokenTyID: return "tok"; + case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth()); + case FunctionTyID: return "fn"; + case PointerTyID: return get_pointer_element_ty()->repr() + "*"; + case StructTyID: return "struct"; + case TileTyID: return tile_repr(); + default: break; + } + assert(false); + return ""; + }; private: context &ctx_; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index b4ff3c5b2..458365a60 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -148,20 +148,10 @@ DEFINE_UNARY_FLOAT(fneg) value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw) { - auto *clhs = dynamic_cast(lhs); - auto *crhs = dynamic_cast(rhs); - if(clhs && crhs){ - constant_expression* result = constant_expression::create(op, clhs, crhs); - if (has_nuw) result->set_has_no_unsigned_wrap(); - if (has_nsw) result->set_has_no_signed_wrap(); - return result; - } - else { - binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name); - if (has_nuw) result->set_has_no_unsigned_wrap(); - if (has_nsw) result->set_has_no_signed_wrap(); - return result; - } + binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name); + if (has_nuw) result->set_has_no_unsigned_wrap(); + if (has_nsw) result->set_has_no_signed_wrap(); + return result; } #define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\ diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index 6df5a58cd..9ff8d6e72 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -120,41 +120,6 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vectorget_type(), 0), - op_(op), lhs_(lhs), rhs_(rhs) { } - - -constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) { - context_impl *impl = lhs->get_type()->get_context().p_impl.get(); - constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)]; - if(!result) - result = new constant_expression(op, lhs, rhs); - return result; -} - -uint64_t constant_expression::get_value() const { - uint64_t lhs = lhs_->get_value(); - uint64_t rhs = rhs_->get_value(); - switch(op_) { - case op_t::Add : return lhs + rhs; - case op_t::Sub : return lhs - rhs; - case op_t::Mul : return lhs * rhs; - case op_t::UDiv : return lhs / rhs; - case op_t::SDiv : return lhs / rhs; - case op_t::URem : return lhs % rhs; - case op_t::SRem : return lhs % rhs; - case op_t::Shl : return lhs << rhs; - case op_t::LShr : return lhs >> rhs; - case op_t::AShr : return lhs >> rhs; - case op_t::And : return lhs && rhs; - case op_t::Or : return lhs || rhs; - case op_t::Xor : return lhs ^ rhs; - default: throw std::runtime_error("unsupported constexpr binary operator"); - } -} - // undef value undef_value::undef_value(type *ty) diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 3d911b967..e7e5de1f2 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -482,18 +482,6 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask // retile_inst classes //===----------------------------------------------------------------------===// -std::string retile_inst::shape_suffix(ir::type* ty){ - std::string res = "["; - const auto& shapes = ty->get_tile_shapes(); - for(unsigned i = 0; i < shapes.size(); i++){ - res += std::to_string(ty->get_tile_shapes()[i]); - if(i < shapes.size() - 1) - res += ", "; - } - res += "]"; - return res; -} - retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next) : unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { } diff --git a/lib/ir/print.cc b/lib/ir/print.cc index 9b66305b8..31cc15d9a 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -44,14 +44,15 @@ void print(module &mod, std::ostream& os) { else os << " = "; } - os << inst->repr(); + ir::type* type = inst->get_type(); + os << inst->repr() << " " << type->repr(); ir::instruction::ops_t ops = inst->ops(); size_t num_ops = inst->get_num_operands(); if(num_ops > 0) os << " ";; for(unsigned i = 0; i < num_ops; i++){ - if(auto *x = dynamic_cast(ops[i])) - os << x->get_value(); + if(auto *x = dynamic_cast(ops[i])) + os << x->repr(); else os << get_name(ops[i], cnt++); os << (i < num_ops - 1?", ":"");