simple constexpr

This commit is contained in:
Philippe Tillet
2019-08-05 13:06:56 -07:00
parent d869d9a924
commit 899b2b72e1
9 changed files with 96 additions and 21 deletions

View File

@@ -3,6 +3,7 @@
#include "value.h" #include "value.h"
#include <cassert> #include <cassert>
#include "llvm/IR/Instructions.h"
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -36,14 +37,14 @@ protected:
constant_int(type *ty, uint64_t value); constant_int(type *ty, uint64_t value);
public: public:
uint64_t get_value() const { return value_; } virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value); static constant_int *get(type *ty, uint64_t value);
protected: protected:
uint64_t value_; uint64_t value_;
}; };
/* Metaparameter int */ /* Metaparameter (int) */
class metaparameter: public constant_int { class metaparameter: public constant_int {
private: private:
metaparameter(type *ty, const std::vector<unsigned>& space); metaparameter(type *ty, const std::vector<unsigned>& space);
@@ -55,12 +56,36 @@ public:
bool has_value() { return has_value_; } bool has_value() { return has_value_; }
const std::vector<unsigned>& get_space() { return space_; } const std::vector<unsigned>& get_space() { return space_; }
void set_space(const std::vector<unsigned> &space) { space_ = space; } void set_space(const std::vector<unsigned> &space) { space_ = space; }
uint64_t get_value() const { assert(has_value_); return value_; }
private: private:
std::vector<unsigned> space_; std::vector<unsigned> space_;
bool has_value_; bool has_value_;
}; };
class constant_expression: public constant_int {
typedef llvm::BinaryOperator::BinaryOps op_t;
using llop = llvm::BinaryOperator::BinaryOps;
private:
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
public:
uint64_t get_value() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
private:
op_t op_;
constant_int* lhs_;
constant_int* rhs_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
/* constant range */ /* constant range */
class constant_range: public constant{ class constant_range: public constant{
constant_range(type *ty, constant_int* first, constant_int* last); constant_range(type *ty, constant_int* first, constant_int* last);

View File

@@ -9,6 +9,8 @@ namespace triton{
namespace ir{ namespace ir{
class context; class context;
class constant;
class constant_expression;
class constant_int; class constant_int;
class constant_fp; class constant_fp;
class undef_value; class undef_value;
@@ -36,6 +38,8 @@ public:
std::map<type*, undef_value*> uv_constants_; std::map<type*, undef_value*> uv_constants_;
// Metaparameters // Metaparameters
std::vector<metaparameter*> mp_constants_; std::vector<metaparameter*> mp_constants_;
// Expr constants
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
}; };
} }

View File

@@ -93,7 +93,6 @@ private:
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// binary_operator classes // binary_operator classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class binary_operator: public instruction{ class binary_operator: public instruction{
public: public:
typedef llvm::BinaryOperator::BinaryOps op_t; typedef llvm::BinaryOperator::BinaryOps op_t;

View File

@@ -93,7 +93,7 @@ abstract_declarator
; ;
direct_abstract_declarator direct_abstract_declarator
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); } : '[' constant_expression_list ']' { $$ = new tile(nullptr, $2); }
type_name type_name
: declaration_specifiers { $$ = new type_name($1, nullptr); } : declaration_specifiers { $$ = new type_name($1, nullptr); }
@@ -133,7 +133,7 @@ builtin_expression
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); } | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); } | ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); } | ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); } | RESHAPE '(' expression ',' constant_expression_list ')' { $$ = new reshape_expression($3, $5); }
; ;
/* Primary */ /* Primary */
@@ -146,11 +146,6 @@ primary_expression
| '(' expression ')' { $$ = $2; } | '(' expression ')' { $$ = $2; }
; ;
primary_expression_list
: primary_expression { $$ = new list<expression*>((expression*)$1); }
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
;
/* Postfix */ /* Postfix */
slice slice
: ':' { $$ = new slice(triton::lang::ALL); } : ':' { $$ = new slice(triton::lang::ALL); }
@@ -279,6 +274,10 @@ expression
: assignment_expression { $$ = $1; } : assignment_expression { $$ = $1; }
; ;
constant_expression_list
: expression { $$ = new list<expression*>((expression*)$1); }
| constant_expression_list ',' expression { $$ = append_ptr_list<expression>($1, $3); }
/* Initialization */ /* Initialization */
initialization_expression initialization_expression
: assignment_expression { $$ = $1; } : assignment_expression { $$ = $1; }
@@ -338,7 +337,7 @@ jump_statement
direct_declarator direct_declarator
: identifier { $$ = $1; } : identifier { $$ = $1; }
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); } | identifier '[' constant_expression_list ']' { $$ = new tile($1, $3); }
| identifier '(' parameter_list ')' { $$ = new function($1, $3); } | identifier '(' parameter_list ')' { $$ = new function($1, $3); }
| identifier '(' ')' { $$ = new function($1, nullptr); } | identifier '(' ')' { $$ = new function($1, nullptr); }
; ;

View File

@@ -72,7 +72,7 @@ public:
void target_independent(ir::module &module) { void target_independent(ir::module &module) {
optimize_dot.run(module); optimize_dot.run(module);
optimize_trans.run(module); optimize_trans.run(module);
// optimize_dce.run(module); optimize_dce.run(module);
} }
void target_dependent(ir::module &module) { void target_dependent(ir::module &module) {
@@ -86,7 +86,7 @@ public:
shmem_barriers.run(module); shmem_barriers.run(module);
} }
vectorize.run(module); vectorize.run(module);
// optimize_dce.run(module); optimize_dce.run(module);
// ir::print(module, std::cout); // ir::print(module, std::cout);
} }

View File

@@ -148,10 +148,20 @@ DEFINE_UNARY_FLOAT(fneg)
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs, value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
value *rhs, const std::string &name, value *rhs, const std::string &name,
bool has_nuw, bool has_nsw) { bool has_nuw, bool has_nsw) {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name); if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
if (has_nuw) result->set_has_no_unsigned_wrap(); if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
if (has_nsw) result->set_has_no_signed_wrap(); constant_expression* result = constant_expression::create(op, clhs, crhs);
return result; if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
}
else {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
if (has_nuw) result->set_has_no_unsigned_wrap();
if (has_nsw) result->set_has_no_signed_wrap();
return result;
}
} }
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\ #define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
@@ -161,7 +171,7 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\ #define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\ value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\ return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
} }
#define DEFINE_UNARY_INT(SUFFIX)\ #define DEFINE_UNARY_INT(SUFFIX)\

View File

@@ -127,6 +127,42 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
return result; return result;
} }
// constant expression
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
: constant_int(lhs->get_type(), 0),
op_(op), lhs_(lhs), rhs_(rhs) { }
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
if(!result)
result = new constant_expression(op, lhs, rhs);
return result;
}
uint64_t constant_expression::get_value() const {
uint64_t lhs = lhs_->get_value();
uint64_t rhs = rhs_->get_value();
switch(op_) {
case llop::Add : return lhs + rhs;
case llop::Sub : return lhs - rhs;
case llop::Mul : return lhs * rhs;
case llop::UDiv : return lhs / rhs;
case llop::SDiv : return lhs / rhs;
case llop::URem : return lhs % rhs;
case llop::SRem : return lhs % rhs;
case llop::Shl : return lhs << rhs;
case llop::LShr : return lhs >> rhs;
case llop::AShr : return lhs >> rhs;
case llop::And : return lhs && rhs;
case llop::Or : return lhs || rhs;
case llop::Xor : return lhs ^ rhs;
default: throw std::runtime_error("unsupported constexpr binary operator");
}
}
// undef value // undef value
undef_value::undef_value(type *ty) undef_value::undef_value(type *ty)
: constant(ty, 0) { } : constant(ty, 0) { }

View File

@@ -79,7 +79,8 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
ir::type::tile_shapes_t shapes; ir::type::tile_shapes_t shapes;
for(expression *expr: shapes_->values()){ for(expression *expr: shapes_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod)); ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
assert(shape); if(shape == nullptr)
throw std::runtime_error("tile shapes must be constant expressions");
shapes.push_back(shape); shapes.push_back(shape);
} }
return ir::tile_type::get(type, shapes); return ir::tile_type::get(type, shapes);

View File

@@ -101,6 +101,7 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
ir::value* binary_expression::codegen(ir::module *mod) const{ ir::value* binary_expression::codegen(ir::module *mod) const{
ir::value *lhs = lhs_->codegen(mod); ir::value *lhs = lhs_->codegen(mod);
ir::value *rhs = rhs_->codegen(mod); ir::value *rhs = rhs_->codegen(mod);
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, ""); ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
return result; return result;
} }
@@ -169,7 +170,8 @@ ir::value* reshape_expression::codegen(ir::module *mod) const {
ir::type::tile_shapes_t shapes; ir::type::tile_shapes_t shapes;
for(expression *expr: shapes_->values()){ for(expression *expr: shapes_->values()){
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod)); ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
assert(shape); if(shape == nullptr)
throw std::runtime_error("tile shapes must be constant expressions");
shapes.push_back(shape); shapes.push_back(shape);
} }
// return // return
@@ -210,7 +212,6 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_sqrt(arg_->codegen(mod)); return mod->get_builder().create_sqrt(arg_->codegen(mod));
} }
// reduce // reduce
ir::value* reduce_expression::codegen(ir::module *mod) const { ir::value* reduce_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value()); return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());