[ir][print] improved pretty-printing of constants and instructions
This commit is contained in:
@@ -21,6 +21,7 @@ protected:
|
|||||||
public:
|
public:
|
||||||
static constant* get_all_ones_value(type *ty);
|
static constant* get_all_ones_value(type *ty);
|
||||||
static constant* get_null_value(type *ty);
|
static constant* get_null_value(type *ty);
|
||||||
|
virtual std::string repr() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Undef value */
|
/* Undef value */
|
||||||
@@ -30,6 +31,7 @@ private:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
static undef_value* get(type* ty);
|
static undef_value* get(type* ty);
|
||||||
|
std::string repr() const { return "undef"; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -40,8 +42,8 @@ protected:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
virtual uint64_t get_value() const { return value_; }
|
virtual uint64_t get_value() const { return value_; }
|
||||||
virtual std::string repr() const { return std::to_string(get_value()); }
|
|
||||||
static constant_int *get(type *ty, uint64_t value);
|
static constant_int *get(type *ty, uint64_t value);
|
||||||
|
std::string repr() const { return std::to_string(value_); }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
uint64_t value_;
|
uint64_t value_;
|
||||||
@@ -66,28 +68,6 @@ private:
|
|||||||
bool has_value_;
|
bool has_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class constant_expression: public constant_int {
|
|
||||||
typedef binary_op_t op_t;
|
|
||||||
|
|
||||||
private:
|
|
||||||
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
|
|
||||||
|
|
||||||
public:
|
|
||||||
uint64_t get_value() const;
|
|
||||||
// Wraps
|
|
||||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
|
||||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
|
||||||
// Factory
|
|
||||||
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
|
|
||||||
|
|
||||||
private:
|
|
||||||
op_t op_;
|
|
||||||
constant_int* lhs_;
|
|
||||||
constant_int* rhs_;
|
|
||||||
bool has_no_unsigned_wrap_;
|
|
||||||
bool has_no_signed_wrap_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/* constant range */
|
/* constant range */
|
||||||
class constant_range: public constant{
|
class constant_range: public constant{
|
||||||
constant_range(type *ty, constant_int* first, constant_int* last);
|
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||||
@@ -96,6 +76,7 @@ public:
|
|||||||
static constant *get(constant_int *first, constant_int *last);
|
static constant *get(constant_int *first, constant_int *last);
|
||||||
const constant_int* get_first() const;
|
const constant_int* get_first() const;
|
||||||
const constant_int* get_last() const;
|
const constant_int* get_last() const;
|
||||||
|
std::string repr() const { return first_->repr() + " ... " + last_->repr(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
constant_int* first_;
|
constant_int* first_;
|
||||||
@@ -112,6 +93,7 @@ public:
|
|||||||
static constant* get_zero_value_for_negation(type *ty);
|
static constant* get_zero_value_for_negation(type *ty);
|
||||||
static constant* get(context &ctx, double v);
|
static constant* get(context &ctx, double v);
|
||||||
static constant* get(type *ty, double v);
|
static constant* get(type *ty, double v);
|
||||||
|
std::string repr() const { return std::to_string(value_); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double value_;
|
double value_;
|
||||||
@@ -128,6 +110,7 @@ public:
|
|||||||
global_value(type *ty, unsigned num_ops,
|
global_value(type *ty, unsigned num_ops,
|
||||||
linkage_types_t linkage, const std::string &name,
|
linkage_types_t linkage, const std::string &name,
|
||||||
unsigned addr_space);
|
unsigned addr_space);
|
||||||
|
std::string repr() const { return get_name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
linkage_types_t linkage_;
|
linkage_types_t linkage_;
|
||||||
@@ -139,6 +122,8 @@ public:
|
|||||||
global_object(type *ty, unsigned num_ops,
|
global_object(type *ty, unsigned num_ops,
|
||||||
linkage_types_t linkage, const std::string &name,
|
linkage_types_t linkage, const std::string &name,
|
||||||
unsigned addr_space = 0);
|
unsigned addr_space = 0);
|
||||||
|
std::string repr() const { return get_name(); }
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/* global variable */
|
/* global variable */
|
||||||
@@ -146,6 +131,8 @@ class alloc_const: public global_object {
|
|||||||
public:
|
public:
|
||||||
alloc_const(type *ty, constant_int *size,
|
alloc_const(type *ty, constant_int *size,
|
||||||
const std::string &name = "");
|
const std::string &name = "");
|
||||||
|
std::string repr() const { return get_name(); }
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -11,7 +11,6 @@ namespace ir{
|
|||||||
|
|
||||||
class context;
|
class context;
|
||||||
class constant;
|
class constant;
|
||||||
class constant_expression;
|
|
||||||
class constant_int;
|
class constant_int;
|
||||||
class constant_fp;
|
class constant_fp;
|
||||||
class undef_value;
|
class undef_value;
|
||||||
@@ -39,8 +38,6 @@ public:
|
|||||||
std::map<type*, undef_value*> uv_constants_;
|
std::map<type*, undef_value*> uv_constants_;
|
||||||
// Metaparameters
|
// Metaparameters
|
||||||
std::vector<metaparameter*> mp_constants_;
|
std::vector<metaparameter*> mp_constants_;
|
||||||
// Expr constants
|
|
||||||
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -309,7 +309,7 @@ public:
|
|||||||
// ternary
|
// ternary
|
||||||
class ternary_inst: public instruction {
|
class ternary_inst: public instruction {
|
||||||
private:
|
private:
|
||||||
std::string repr_impl() const { return "ternary"; }
|
std::string repr_impl() const { return "cond"; }
|
||||||
ternary_inst(value *cond, value *true_value, value *false_value,
|
ternary_inst(value *cond, value *true_value, value *false_value,
|
||||||
const std::string &name, instruction *next);
|
const std::string &name, instruction *next);
|
||||||
|
|
||||||
@@ -438,7 +438,6 @@ public:
|
|||||||
class retile_inst: public unary_inst {
|
class retile_inst: public unary_inst {
|
||||||
protected:
|
protected:
|
||||||
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||||
static std::string shape_suffix(ir::type* ty);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// reshape
|
// reshape
|
||||||
@@ -446,7 +445,7 @@ protected:
|
|||||||
class reshape_inst: public retile_inst {
|
class reshape_inst: public retile_inst {
|
||||||
private:
|
private:
|
||||||
using retile_inst::retile_inst;
|
using retile_inst::retile_inst;
|
||||||
std::string repr_impl() const { return "reshape" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "reshape"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
@@ -458,7 +457,7 @@ public:
|
|||||||
class splat_inst: public retile_inst {
|
class splat_inst: public retile_inst {
|
||||||
private:
|
private:
|
||||||
using retile_inst::retile_inst;
|
using retile_inst::retile_inst;
|
||||||
std::string repr_impl() const { return "splat" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "splat"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
@@ -470,7 +469,7 @@ public:
|
|||||||
class broadcast_inst: public retile_inst {
|
class broadcast_inst: public retile_inst {
|
||||||
private:
|
private:
|
||||||
using retile_inst::retile_inst;
|
using retile_inst::retile_inst;
|
||||||
std::string repr_impl() const { return "broadcast" + shape_suffix(get_type()); }
|
std::string repr_impl() const { return "broadcast"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||||
@@ -688,6 +687,7 @@ private:
|
|||||||
public:
|
public:
|
||||||
static nv_static_program_idx *get(constant_range* range);
|
static nv_static_program_idx *get(constant_range* range);
|
||||||
constant_range* get_range() const;
|
constant_range* get_range() const;
|
||||||
|
std::string repr() const { return get_name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
constant_range *range_;
|
constant_range *range_;
|
||||||
|
@@ -3,7 +3,9 @@
|
|||||||
#ifndef _TRITON_IR_TYPE_H_
|
#ifndef _TRITON_IR_TYPE_H_
|
||||||
#define _TRITON_IR_TYPE_H_
|
#define _TRITON_IR_TYPE_H_
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -102,6 +104,42 @@ public:
|
|||||||
static integer_type *get_int64_ty(context &ctx);
|
static integer_type *get_int64_ty(context &ctx);
|
||||||
static integer_type *get_int128_ty(context &ctx);
|
static integer_type *get_int128_ty(context &ctx);
|
||||||
|
|
||||||
|
// repr
|
||||||
|
std::string tile_repr() const {
|
||||||
|
std::string res = get_tile_element_ty()->repr();
|
||||||
|
auto shapes = get_tile_shapes();
|
||||||
|
res += "<";
|
||||||
|
for(size_t i = 0; i < shapes.size(); i++){
|
||||||
|
if(i > 0)
|
||||||
|
res += ", ";
|
||||||
|
res += std::to_string(shapes[i]);
|
||||||
|
}
|
||||||
|
res+= ">";
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string repr() const {
|
||||||
|
switch(id_) {
|
||||||
|
case VoidTyID: return "void";
|
||||||
|
case HalfTyID: return "f16";
|
||||||
|
case FloatTyID: return "f32";
|
||||||
|
case DoubleTyID: return "f64";
|
||||||
|
case X86_FP80TyID: return "f80";
|
||||||
|
case FP128TyID: return "f128";
|
||||||
|
case PPC_FP128TyID: return "ppcf128";
|
||||||
|
case LabelTyID: return "label";
|
||||||
|
case MetadataTyID: return "md";
|
||||||
|
case TokenTyID: return "tok";
|
||||||
|
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
|
||||||
|
case FunctionTyID: return "fn";
|
||||||
|
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||||
|
case StructTyID: return "struct";
|
||||||
|
case TileTyID: return tile_repr();
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
assert(false);
|
||||||
|
return "";
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
context &ctx_;
|
context &ctx_;
|
||||||
|
@@ -148,20 +148,10 @@ DEFINE_UNARY_FLOAT(fneg)
|
|||||||
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||||
value *rhs, const std::string &name,
|
value *rhs, const std::string &name,
|
||||||
bool has_nuw, bool has_nsw) {
|
bool has_nuw, bool has_nsw) {
|
||||||
auto *clhs = dynamic_cast<constant_int*>(lhs);
|
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
||||||
auto *crhs = dynamic_cast<constant_int*>(rhs);
|
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||||
if(clhs && crhs){
|
if (has_nsw) result->set_has_no_signed_wrap();
|
||||||
constant_expression* result = constant_expression::create(op, clhs, crhs);
|
return result;
|
||||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
|
||||||
if (has_nsw) result->set_has_no_signed_wrap();
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
|
||||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
|
||||||
if (has_nsw) result->set_has_no_signed_wrap();
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||||
|
@@ -120,41 +120,6 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// constant expression
|
|
||||||
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
|
|
||||||
: constant_int(lhs->get_type(), 0),
|
|
||||||
op_(op), lhs_(lhs), rhs_(rhs) { }
|
|
||||||
|
|
||||||
|
|
||||||
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
|
|
||||||
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
|
|
||||||
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
|
|
||||||
if(!result)
|
|
||||||
result = new constant_expression(op, lhs, rhs);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t constant_expression::get_value() const {
|
|
||||||
uint64_t lhs = lhs_->get_value();
|
|
||||||
uint64_t rhs = rhs_->get_value();
|
|
||||||
switch(op_) {
|
|
||||||
case op_t::Add : return lhs + rhs;
|
|
||||||
case op_t::Sub : return lhs - rhs;
|
|
||||||
case op_t::Mul : return lhs * rhs;
|
|
||||||
case op_t::UDiv : return lhs / rhs;
|
|
||||||
case op_t::SDiv : return lhs / rhs;
|
|
||||||
case op_t::URem : return lhs % rhs;
|
|
||||||
case op_t::SRem : return lhs % rhs;
|
|
||||||
case op_t::Shl : return lhs << rhs;
|
|
||||||
case op_t::LShr : return lhs >> rhs;
|
|
||||||
case op_t::AShr : return lhs >> rhs;
|
|
||||||
case op_t::And : return lhs && rhs;
|
|
||||||
case op_t::Or : return lhs || rhs;
|
|
||||||
case op_t::Xor : return lhs ^ rhs;
|
|
||||||
default: throw std::runtime_error("unsupported constexpr binary operator");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// undef value
|
// undef value
|
||||||
undef_value::undef_value(type *ty)
|
undef_value::undef_value(type *ty)
|
||||||
|
@@ -482,18 +482,6 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
|||||||
// retile_inst classes
|
// retile_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
std::string retile_inst::shape_suffix(ir::type* ty){
|
|
||||||
std::string res = "[";
|
|
||||||
const auto& shapes = ty->get_tile_shapes();
|
|
||||||
for(unsigned i = 0; i < shapes.size(); i++){
|
|
||||||
res += std::to_string(ty->get_tile_shapes()[i]);
|
|
||||||
if(i < shapes.size() - 1)
|
|
||||||
res += ", ";
|
|
||||||
}
|
|
||||||
res += "]";
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
||||||
|
@@ -44,14 +44,15 @@ void print(module &mod, std::ostream& os) {
|
|||||||
else
|
else
|
||||||
os << " = ";
|
os << " = ";
|
||||||
}
|
}
|
||||||
os << inst->repr();
|
ir::type* type = inst->get_type();
|
||||||
|
os << inst->repr() << " " << type->repr();
|
||||||
ir::instruction::ops_t ops = inst->ops();
|
ir::instruction::ops_t ops = inst->ops();
|
||||||
size_t num_ops = inst->get_num_operands();
|
size_t num_ops = inst->get_num_operands();
|
||||||
if(num_ops > 0)
|
if(num_ops > 0)
|
||||||
os << " ";;
|
os << " ";;
|
||||||
for(unsigned i = 0; i < num_ops; i++){
|
for(unsigned i = 0; i < num_ops; i++){
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(ops[i]))
|
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
|
||||||
os << x->get_value();
|
os << x->repr();
|
||||||
else
|
else
|
||||||
os << get_name(ops[i], cnt++);
|
os << get_name(ops[i], cnt++);
|
||||||
os << (i < num_ops - 1?", ":"");
|
os << (i < num_ops - 1?", ":"");
|
||||||
|
Reference in New Issue
Block a user