[ir][print] improved pretty-printing of constants and instructions
This commit is contained in:
@@ -21,6 +21,7 @@ protected:
|
||||
public:
|
||||
static constant* get_all_ones_value(type *ty);
|
||||
static constant* get_null_value(type *ty);
|
||||
virtual std::string repr() const = 0;
|
||||
};
|
||||
|
||||
/* Undef value */
|
||||
@@ -30,6 +31,7 @@ private:
|
||||
|
||||
public:
|
||||
static undef_value* get(type* ty);
|
||||
std::string repr() const { return "undef"; }
|
||||
};
|
||||
|
||||
|
||||
@@ -40,8 +42,8 @@ protected:
|
||||
|
||||
public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
virtual std::string repr() const { return std::to_string(get_value()); }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
@@ -66,28 +68,6 @@ private:
|
||||
bool has_value_;
|
||||
};
|
||||
|
||||
class constant_expression: public constant_int {
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
|
||||
|
||||
public:
|
||||
uint64_t get_value() const;
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
// Factory
|
||||
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
|
||||
|
||||
private:
|
||||
op_t op_;
|
||||
constant_int* lhs_;
|
||||
constant_int* rhs_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class constant_range: public constant{
|
||||
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||
@@ -96,6 +76,7 @@ public:
|
||||
static constant *get(constant_int *first, constant_int *last);
|
||||
const constant_int* get_first() const;
|
||||
const constant_int* get_last() const;
|
||||
std::string repr() const { return first_->repr() + " ... " + last_->repr(); }
|
||||
|
||||
private:
|
||||
constant_int* first_;
|
||||
@@ -112,6 +93,7 @@ public:
|
||||
static constant* get_zero_value_for_negation(type *ty);
|
||||
static constant* get(context &ctx, double v);
|
||||
static constant* get(type *ty, double v);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
|
||||
private:
|
||||
double value_;
|
||||
@@ -128,6 +110,7 @@ public:
|
||||
global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
private:
|
||||
linkage_types_t linkage_;
|
||||
@@ -139,6 +122,8 @@ public:
|
||||
global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space = 0);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
};
|
||||
|
||||
/* global variable */
|
||||
@@ -146,6 +131,8 @@ class alloc_const: public global_object {
|
||||
public:
|
||||
alloc_const(type *ty, constant_int *size,
|
||||
const std::string &name = "");
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -11,7 +11,6 @@ namespace ir{
|
||||
|
||||
class context;
|
||||
class constant;
|
||||
class constant_expression;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
@@ -39,8 +38,6 @@ public:
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
// Metaparameters
|
||||
std::vector<metaparameter*> mp_constants_;
|
||||
// Expr constants
|
||||
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -309,7 +309,7 @@ public:
|
||||
// ternary
|
||||
class ternary_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "ternary"; }
|
||||
std::string repr_impl() const { return "cond"; }
|
||||
ternary_inst(value *cond, value *true_value, value *false_value,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
@@ -438,7 +438,6 @@ public:
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
static std::string shape_suffix(ir::type* ty);
|
||||
};
|
||||
|
||||
// reshape
|
||||
@@ -446,7 +445,7 @@ protected:
|
||||
class reshape_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
|
||||
std::string repr_impl() const { return "reshape"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
@@ -458,7 +457,7 @@ public:
|
||||
class splat_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
|
||||
std::string repr_impl() const { return "splat"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
@@ -470,7 +469,7 @@ public:
|
||||
class broadcast_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
|
||||
std::string repr_impl() const { return "broadcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
@@ -688,6 +687,7 @@ private:
|
||||
public:
|
||||
static nv_static_program_idx *get(constant_range* range);
|
||||
constant_range* get_range() const;
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
private:
|
||||
constant_range *range_;
|
||||
|
@@ -3,7 +3,9 @@
|
||||
#ifndef _TRITON_IR_TYPE_H_
|
||||
#define _TRITON_IR_TYPE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -102,6 +104,42 @@ public:
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
std::string res = get_tile_element_ty()->repr();
|
||||
auto shapes = get_tile_shapes();
|
||||
res += "<";
|
||||
for(size_t i = 0; i < shapes.size(); i++){
|
||||
if(i > 0)
|
||||
res += ", ";
|
||||
res += std::to_string(shapes[i]);
|
||||
}
|
||||
res+= ">";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case HalfTyID: return "f16";
|
||||
case FloatTyID: return "f32";
|
||||
case DoubleTyID: return "f64";
|
||||
case X86_FP80TyID: return "f80";
|
||||
case FP128TyID: return "f128";
|
||||
case PPC_FP128TyID: return "ppcf128";
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case TileTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
};
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
|
@@ -148,20 +148,10 @@ DEFINE_UNARY_FLOAT(fneg)
|
||||
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||
value *rhs, const std::string &name,
|
||||
bool has_nuw, bool has_nsw) {
|
||||
auto *clhs = dynamic_cast<constant_int*>(lhs);
|
||||
auto *crhs = dynamic_cast<constant_int*>(rhs);
|
||||
if(clhs && crhs){
|
||||
constant_expression* result = constant_expression::create(op, clhs, crhs);
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
else {
|
||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
|
||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||
|
@@ -120,41 +120,6 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
|
||||
return result;
|
||||
}
|
||||
|
||||
// constant expression
|
||||
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
|
||||
: constant_int(lhs->get_type(), 0),
|
||||
op_(op), lhs_(lhs), rhs_(rhs) { }
|
||||
|
||||
|
||||
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
|
||||
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
|
||||
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
|
||||
if(!result)
|
||||
result = new constant_expression(op, lhs, rhs);
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t constant_expression::get_value() const {
|
||||
uint64_t lhs = lhs_->get_value();
|
||||
uint64_t rhs = rhs_->get_value();
|
||||
switch(op_) {
|
||||
case op_t::Add : return lhs + rhs;
|
||||
case op_t::Sub : return lhs - rhs;
|
||||
case op_t::Mul : return lhs * rhs;
|
||||
case op_t::UDiv : return lhs / rhs;
|
||||
case op_t::SDiv : return lhs / rhs;
|
||||
case op_t::URem : return lhs % rhs;
|
||||
case op_t::SRem : return lhs % rhs;
|
||||
case op_t::Shl : return lhs << rhs;
|
||||
case op_t::LShr : return lhs >> rhs;
|
||||
case op_t::AShr : return lhs >> rhs;
|
||||
case op_t::And : return lhs && rhs;
|
||||
case op_t::Or : return lhs || rhs;
|
||||
case op_t::Xor : return lhs ^ rhs;
|
||||
default: throw std::runtime_error("unsupported constexpr binary operator");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
|
@@ -482,18 +482,6 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string retile_inst::shape_suffix(ir::type* ty){
|
||||
std::string res = "[";
|
||||
const auto& shapes = ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
res += std::to_string(ty->get_tile_shapes()[i]);
|
||||
if(i < shapes.size() - 1)
|
||||
res += ", ";
|
||||
}
|
||||
res += "]";
|
||||
return res;
|
||||
}
|
||||
|
||||
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
||||
|
@@ -44,14 +44,15 @@ void print(module &mod, std::ostream& os) {
|
||||
else
|
||||
os << " = ";
|
||||
}
|
||||
os << inst->repr();
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
size_t num_ops = inst->get_num_operands();
|
||||
if(num_ops > 0)
|
||||
os << " ";;
|
||||
for(unsigned i = 0; i < num_ops; i++){
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(ops[i]))
|
||||
os << x->get_value();
|
||||
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||
os << x->repr();
|
||||
else
|
||||
os << get_name(ops[i], cnt++);
|
||||
os << (i < num_ops - 1?", ":"");
|
||||
|
Reference in New Issue
Block a user