simple constexpr

This commit is contained in:
Philippe Tillet
2019-08-05 13:06:56 -07:00
parent d869d9a924
commit 899b2b72e1
9 changed files with 96 additions and 21 deletions

View File

@@ -3,6 +3,7 @@
#include "value.h"
#include <cassert>
#include "llvm/IR/Instructions.h"
namespace triton{
namespace ir{
@@ -36,14 +37,14 @@ protected:
constant_int(type *ty, uint64_t value);
public:
uint64_t get_value() const { return value_; }
virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value);
protected:
uint64_t value_;
};
/* Metaparameter int */
/* Metaparameter (int) */
class metaparameter: public constant_int {
private:
metaparameter(type *ty, const std::vector<unsigned>& space);
@@ -55,12 +56,36 @@ public:
bool has_value() { return has_value_; }
const std::vector<unsigned>& get_space() { return space_; }
void set_space(const std::vector<unsigned> &space) { space_ = space; }
uint64_t get_value() const { assert(has_value_); return value_; }
private:
std::vector<unsigned> space_;
bool has_value_;
};
class constant_expression: public constant_int {
typedef llvm::BinaryOperator::BinaryOps op_t;
using llop = llvm::BinaryOperator::BinaryOps;
private:
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
public:
uint64_t get_value() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
private:
op_t op_;
constant_int* lhs_;
constant_int* rhs_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
/* constant range */
class constant_range: public constant{
constant_range(type *ty, constant_int* first, constant_int* last);

View File

@@ -9,6 +9,8 @@ namespace triton{
namespace ir{
class context;
class constant;
class constant_expression;
class constant_int;
class constant_fp;
class undef_value;
@@ -36,6 +38,8 @@ public:
std::map<type*, undef_value*> uv_constants_;
// Metaparameters
std::vector<metaparameter*> mp_constants_;
// Expr constants
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
};
}

View File

@@ -93,7 +93,6 @@ private:
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction{
public:
typedef llvm::BinaryOperator::BinaryOps op_t;

View File

@@ -93,7 +93,7 @@ abstract_declarator
;
direct_abstract_declarator
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
: '[' constant_expression_list ']' { $$ = new tile(nullptr, $2); }
type_name
: declaration_specifiers { $$ = new type_name($1, nullptr); }
@@ -133,7 +133,7 @@ builtin_expression
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
| RESHAPE '(' expression ',' constant_expression_list ')' { $$ = new reshape_expression($3, $5); }
;
/* Primary */
@@ -146,11 +146,6 @@ primary_expression
| '(' expression ')' { $$ = $2; }
;
primary_expression_list
: primary_expression { $$ = new list<expression*>((expression*)$1); }
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
;
/* Postfix */
slice
: ':' { $$ = new slice(triton::lang::ALL); }
@@ -279,6 +274,10 @@ expression
: assignment_expression { $$ = $1; }
;
constant_expression_list
: expression { $$ = new list<expression*>((expression*)$1); }
| constant_expression_list ',' expression { $$ = append_ptr_list<expression>($1, $3); }
/* Initialization */
initialization_expression
: assignment_expression { $$ = $1; }
@@ -338,7 +337,7 @@ jump_statement
direct_declarator
: identifier { $$ = $1; }
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
| identifier '[' constant_expression_list ']' { $$ = new tile($1, $3); }
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
| identifier '(' ')' { $$ = new function($1, nullptr); }
;

View File

@@ -72,7 +72,7 @@ public:
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
// optimize_dce.run(module);
optimize_dce.run(module);
}
void target_dependent(ir::module &module) {
@@ -86,7 +86,7 @@ public:
shmem_barriers.run(module);
}
vectorize.run(module);
// optimize_dce.run(module);
optimize_dce.run(module);
// ir::print(module, std::cout);
}

View File

@@ -148,11 +148,21 @@ DEFINE_UNARY_FLOAT(fneg)
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
value *rhs, const std::string &name,
bool has_nuw, bool has_nsw) {
if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
constant_expression* result = constant_expression::create(op, clhs, crhs);
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
}
else {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
}
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
value* builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw){\
@@ -161,7 +171,7 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
}
#define DEFINE_UNARY_INT(SUFFIX)\

View File

@@ -127,6 +127,42 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
return result;
}
// constant expression
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
: constant_int(lhs->get_type(), 0),
op_(op), lhs_(lhs), rhs_(rhs) { }
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
if(!result)
result = new constant_expression(op, lhs, rhs);
return result;
}
uint64_t constant_expression::get_value() const {
uint64_t lhs = lhs_->get_value();
uint64_t rhs = rhs_->get_value();
switch(op_) {
case llop::Add : return lhs + rhs;
case llop::Sub : return lhs - rhs;
case llop::Mul : return lhs * rhs;
case llop::UDiv : return lhs / rhs;
case llop::SDiv : return lhs / rhs;
case llop::URem : return lhs % rhs;
case llop::SRem : return lhs % rhs;
case llop::Shl : return lhs << rhs;
case llop::LShr : return lhs >> rhs;
case llop::AShr : return lhs >> rhs;
case llop::And : return lhs && rhs;
case llop::Or : return lhs || rhs;
case llop::Xor : return lhs ^ rhs;
default: throw std::runtime_error("unsupported constexpr binary operator");
}
}
// undef value
undef_value::undef_value(type *ty)
: constant(ty, 0) { }

View File

@@ -79,7 +79,8 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
ir::type::tile_shapes_t shapes;
for(expression *expr: shapes_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
assert(shape);
if(shape == nullptr)
throw std::runtime_error("tile shapes must be constant expressions");
shapes.push_back(shape);
}
return ir::tile_type::get(type, shapes);

View File

@@ -101,6 +101,7 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
ir::value* binary_expression::codegen(ir::module *mod) const{
ir::value *lhs = lhs_->codegen(mod);
ir::value *rhs = rhs_->codegen(mod);
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
return result;
}
@@ -169,7 +170,8 @@ ir::value* reshape_expression::codegen(ir::module *mod) const {
ir::type::tile_shapes_t shapes;
for(expression *expr: shapes_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
assert(shape);
if(shape == nullptr)
throw std::runtime_error("tile shapes must be constant expressions");
shapes.push_back(shape);
}
// return
@@ -210,7 +212,6 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_sqrt(arg_->codegen(mod));
}
// reduce
ir::value* reduce_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());