simple constexpr
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "llvm/IR/Instructions.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -36,14 +37,14 @@ protected:
|
||||
constant_int(type *ty, uint64_t value);
|
||||
|
||||
public:
|
||||
uint64_t get_value() const { return value_; }
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Metaparameter int */
|
||||
/* Metaparameter (int) */
|
||||
class metaparameter: public constant_int {
|
||||
private:
|
||||
metaparameter(type *ty, const std::vector<unsigned>& space);
|
||||
@@ -55,12 +56,36 @@ public:
|
||||
bool has_value() { return has_value_; }
|
||||
const std::vector<unsigned>& get_space() { return space_; }
|
||||
void set_space(const std::vector<unsigned> &space) { space_ = space; }
|
||||
uint64_t get_value() const { assert(has_value_); return value_; }
|
||||
|
||||
private:
|
||||
std::vector<unsigned> space_;
|
||||
bool has_value_;
|
||||
};
|
||||
|
||||
class constant_expression: public constant_int {
|
||||
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||
using llop = llvm::BinaryOperator::BinaryOps;
|
||||
|
||||
private:
|
||||
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
|
||||
|
||||
public:
|
||||
uint64_t get_value() const;
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
// Factory
|
||||
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
|
||||
|
||||
private:
|
||||
op_t op_;
|
||||
constant_int* lhs_;
|
||||
constant_int* rhs_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class constant_range: public constant{
|
||||
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||
|
@@ -9,6 +9,8 @@ namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class constant;
|
||||
class constant_expression;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
@@ -36,6 +38,8 @@ public:
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
// Metaparameters
|
||||
std::vector<metaparameter*> mp_constants_;
|
||||
// Expr constants
|
||||
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -93,7 +93,6 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction{
|
||||
public:
|
||||
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||
|
@@ -93,7 +93,7 @@ abstract_declarator
|
||||
;
|
||||
|
||||
direct_abstract_declarator
|
||||
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
||||
: '[' constant_expression_list ']' { $$ = new tile(nullptr, $2); }
|
||||
|
||||
type_name
|
||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||
@@ -133,7 +133,7 @@ builtin_expression
|
||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
|
||||
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
||||
| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
|
||||
| RESHAPE '(' expression ',' constant_expression_list ')' { $$ = new reshape_expression($3, $5); }
|
||||
;
|
||||
|
||||
/* Primary */
|
||||
@@ -146,11 +146,6 @@ primary_expression
|
||||
| '(' expression ')' { $$ = $2; }
|
||||
;
|
||||
|
||||
primary_expression_list
|
||||
: primary_expression { $$ = new list<expression*>((expression*)$1); }
|
||||
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||
;
|
||||
|
||||
/* Postfix */
|
||||
slice
|
||||
: ':' { $$ = new slice(triton::lang::ALL); }
|
||||
@@ -279,6 +274,10 @@ expression
|
||||
: assignment_expression { $$ = $1; }
|
||||
;
|
||||
|
||||
constant_expression_list
|
||||
: expression { $$ = new list<expression*>((expression*)$1); }
|
||||
| constant_expression_list ',' expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||
|
||||
/* Initialization */
|
||||
initialization_expression
|
||||
: assignment_expression { $$ = $1; }
|
||||
@@ -338,7 +337,7 @@ jump_statement
|
||||
|
||||
direct_declarator
|
||||
: identifier { $$ = $1; }
|
||||
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
|
||||
| identifier '[' constant_expression_list ']' { $$ = new tile($1, $3); }
|
||||
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
||||
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
||||
;
|
||||
|
@@ -72,7 +72,7 @@ public:
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
// optimize_dce.run(module);
|
||||
optimize_dce.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
@@ -86,7 +86,7 @@ public:
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
// optimize_dce.run(module);
|
||||
optimize_dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
|
@@ -148,11 +148,21 @@ DEFINE_UNARY_FLOAT(fneg)
|
||||
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
|
||||
value *rhs, const std::string &name,
|
||||
bool has_nuw, bool has_nsw) {
|
||||
if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
|
||||
if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
|
||||
constant_expression* result = constant_expression::create(op, clhs, crhs);
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
else {
|
||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||
if (has_nsw) result->set_has_no_signed_wrap();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw){\
|
||||
@@ -161,7 +171,7 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
|
||||
|
||||
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
|
||||
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
|
||||
}
|
||||
|
||||
#define DEFINE_UNARY_INT(SUFFIX)\
|
||||
|
@@ -127,6 +127,42 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
|
||||
return result;
|
||||
}
|
||||
|
||||
// constant expression
|
||||
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
|
||||
: constant_int(lhs->get_type(), 0),
|
||||
op_(op), lhs_(lhs), rhs_(rhs) { }
|
||||
|
||||
|
||||
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
|
||||
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
|
||||
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
|
||||
if(!result)
|
||||
result = new constant_expression(op, lhs, rhs);
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t constant_expression::get_value() const {
|
||||
uint64_t lhs = lhs_->get_value();
|
||||
uint64_t rhs = rhs_->get_value();
|
||||
switch(op_) {
|
||||
case llop::Add : return lhs + rhs;
|
||||
case llop::Sub : return lhs - rhs;
|
||||
case llop::Mul : return lhs * rhs;
|
||||
case llop::UDiv : return lhs / rhs;
|
||||
case llop::SDiv : return lhs / rhs;
|
||||
case llop::URem : return lhs % rhs;
|
||||
case llop::SRem : return lhs % rhs;
|
||||
case llop::Shl : return lhs << rhs;
|
||||
case llop::LShr : return lhs >> rhs;
|
||||
case llop::AShr : return lhs >> rhs;
|
||||
case llop::And : return lhs && rhs;
|
||||
case llop::Or : return lhs || rhs;
|
||||
case llop::Xor : return lhs ^ rhs;
|
||||
default: throw std::runtime_error("unsupported constexpr binary operator");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// undef value
|
||||
undef_value::undef_value(type *ty)
|
||||
: constant(ty, 0) { }
|
||||
|
@@ -79,7 +79,8 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
|
||||
ir::type::tile_shapes_t shapes;
|
||||
for(expression *expr: shapes_->values()){
|
||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||
assert(shape);
|
||||
if(shape == nullptr)
|
||||
throw std::runtime_error("tile shapes must be constant expressions");
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
return ir::tile_type::get(type, shapes);
|
||||
|
@@ -101,6 +101,7 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
|
||||
ir::value* binary_expression::codegen(ir::module *mod) const{
|
||||
ir::value *lhs = lhs_->codegen(mod);
|
||||
ir::value *rhs = rhs_->codegen(mod);
|
||||
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
|
||||
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
|
||||
return result;
|
||||
}
|
||||
@@ -169,7 +170,8 @@ ir::value* reshape_expression::codegen(ir::module *mod) const {
|
||||
ir::type::tile_shapes_t shapes;
|
||||
for(expression *expr: shapes_->values()){
|
||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||
assert(shape);
|
||||
if(shape == nullptr)
|
||||
throw std::runtime_error("tile shapes must be constant expressions");
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
// return
|
||||
@@ -210,7 +212,6 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_sqrt(arg_->codegen(mod));
|
||||
}
|
||||
|
||||
|
||||
// reduce
|
||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
|
||||
|
Reference in New Issue
Block a user