simple constexpr
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include "value.h"
|
#include "value.h"
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include "llvm/IR/Instructions.h"
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -36,14 +37,14 @@ protected:
|
|||||||
constant_int(type *ty, uint64_t value);
|
constant_int(type *ty, uint64_t value);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
uint64_t get_value() const { return value_; }
|
virtual uint64_t get_value() const { return value_; }
|
||||||
static constant_int *get(type *ty, uint64_t value);
|
static constant_int *get(type *ty, uint64_t value);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
uint64_t value_;
|
uint64_t value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Metaparameter int */
|
/* Metaparameter (int) */
|
||||||
class metaparameter: public constant_int {
|
class metaparameter: public constant_int {
|
||||||
private:
|
private:
|
||||||
metaparameter(type *ty, const std::vector<unsigned>& space);
|
metaparameter(type *ty, const std::vector<unsigned>& space);
|
||||||
@@ -55,12 +56,36 @@ public:
|
|||||||
bool has_value() { return has_value_; }
|
bool has_value() { return has_value_; }
|
||||||
const std::vector<unsigned>& get_space() { return space_; }
|
const std::vector<unsigned>& get_space() { return space_; }
|
||||||
void set_space(const std::vector<unsigned> &space) { space_ = space; }
|
void set_space(const std::vector<unsigned> &space) { space_ = space; }
|
||||||
|
uint64_t get_value() const { assert(has_value_); return value_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<unsigned> space_;
|
std::vector<unsigned> space_;
|
||||||
bool has_value_;
|
bool has_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class constant_expression: public constant_int {
|
||||||
|
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||||
|
using llop = llvm::BinaryOperator::BinaryOps;
|
||||||
|
|
||||||
|
private:
|
||||||
|
constant_expression(op_t op, constant_int* lhs, constant_int* rhs);
|
||||||
|
|
||||||
|
public:
|
||||||
|
uint64_t get_value() const;
|
||||||
|
// Wraps
|
||||||
|
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||||
|
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||||
|
// Factory
|
||||||
|
static constant_expression *create(op_t op, constant_int* lhs, constant_int* rhs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
op_t op_;
|
||||||
|
constant_int* lhs_;
|
||||||
|
constant_int* rhs_;
|
||||||
|
bool has_no_unsigned_wrap_;
|
||||||
|
bool has_no_signed_wrap_;
|
||||||
|
};
|
||||||
|
|
||||||
/* constant range */
|
/* constant range */
|
||||||
class constant_range: public constant{
|
class constant_range: public constant{
|
||||||
constant_range(type *ty, constant_int* first, constant_int* last);
|
constant_range(type *ty, constant_int* first, constant_int* last);
|
||||||
|
@@ -9,6 +9,8 @@ namespace triton{
|
|||||||
namespace ir{
|
namespace ir{
|
||||||
|
|
||||||
class context;
|
class context;
|
||||||
|
class constant;
|
||||||
|
class constant_expression;
|
||||||
class constant_int;
|
class constant_int;
|
||||||
class constant_fp;
|
class constant_fp;
|
||||||
class undef_value;
|
class undef_value;
|
||||||
@@ -36,6 +38,8 @@ public:
|
|||||||
std::map<type*, undef_value*> uv_constants_;
|
std::map<type*, undef_value*> uv_constants_;
|
||||||
// Metaparameters
|
// Metaparameters
|
||||||
std::vector<metaparameter*> mp_constants_;
|
std::vector<metaparameter*> mp_constants_;
|
||||||
|
// Expr constants
|
||||||
|
std::map<std::tuple<int, constant*, constant*>, constant_expression*> expr_constants_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -93,7 +93,6 @@ private:
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// binary_operator classes
|
// binary_operator classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class binary_operator: public instruction{
|
class binary_operator: public instruction{
|
||||||
public:
|
public:
|
||||||
typedef llvm::BinaryOperator::BinaryOps op_t;
|
typedef llvm::BinaryOperator::BinaryOps op_t;
|
||||||
|
@@ -93,7 +93,7 @@ abstract_declarator
|
|||||||
;
|
;
|
||||||
|
|
||||||
direct_abstract_declarator
|
direct_abstract_declarator
|
||||||
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
: '[' constant_expression_list ']' { $$ = new tile(nullptr, $2); }
|
||||||
|
|
||||||
type_name
|
type_name
|
||||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||||
@@ -133,7 +133,7 @@ builtin_expression
|
|||||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||||
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
|
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
|
||||||
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
||||||
| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
|
| RESHAPE '(' expression ',' constant_expression_list ')' { $$ = new reshape_expression($3, $5); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Primary */
|
/* Primary */
|
||||||
@@ -146,11 +146,6 @@ primary_expression
|
|||||||
| '(' expression ')' { $$ = $2; }
|
| '(' expression ')' { $$ = $2; }
|
||||||
;
|
;
|
||||||
|
|
||||||
primary_expression_list
|
|
||||||
: primary_expression { $$ = new list<expression*>((expression*)$1); }
|
|
||||||
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
|
||||||
;
|
|
||||||
|
|
||||||
/* Postfix */
|
/* Postfix */
|
||||||
slice
|
slice
|
||||||
: ':' { $$ = new slice(triton::lang::ALL); }
|
: ':' { $$ = new slice(triton::lang::ALL); }
|
||||||
@@ -279,6 +274,10 @@ expression
|
|||||||
: assignment_expression { $$ = $1; }
|
: assignment_expression { $$ = $1; }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
constant_expression_list
|
||||||
|
: expression { $$ = new list<expression*>((expression*)$1); }
|
||||||
|
| constant_expression_list ',' expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||||
|
|
||||||
/* Initialization */
|
/* Initialization */
|
||||||
initialization_expression
|
initialization_expression
|
||||||
: assignment_expression { $$ = $1; }
|
: assignment_expression { $$ = $1; }
|
||||||
@@ -338,7 +337,7 @@ jump_statement
|
|||||||
|
|
||||||
direct_declarator
|
direct_declarator
|
||||||
: identifier { $$ = $1; }
|
: identifier { $$ = $1; }
|
||||||
| identifier '[' primary_expression_list ']' { $$ = new tile($1, $3); }
|
| identifier '[' constant_expression_list ']' { $$ = new tile($1, $3); }
|
||||||
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
|
||||||
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
| identifier '(' ')' { $$ = new function($1, nullptr); }
|
||||||
;
|
;
|
||||||
|
@@ -72,7 +72,7 @@ public:
|
|||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
// optimize_dce.run(module);
|
optimize_dce.run(module);
|
||||||
}
|
}
|
||||||
|
|
||||||
void target_dependent(ir::module &module) {
|
void target_dependent(ir::module &module) {
|
||||||
@@ -86,7 +86,7 @@ public:
|
|||||||
shmem_barriers.run(module);
|
shmem_barriers.run(module);
|
||||||
}
|
}
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
// optimize_dce.run(module);
|
optimize_dce.run(module);
|
||||||
// ir::print(module, std::cout);
|
// ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -148,10 +148,20 @@ DEFINE_UNARY_FLOAT(fneg)
|
|||||||
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
|
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
|
||||||
value *rhs, const std::string &name,
|
value *rhs, const std::string &name,
|
||||||
bool has_nuw, bool has_nsw) {
|
bool has_nuw, bool has_nsw) {
|
||||||
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
|
||||||
if (has_nuw) result->set_has_no_unsigned_wrap();
|
if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
|
||||||
if (has_nsw) result->set_has_no_signed_wrap();
|
constant_expression* result = constant_expression::create(op, clhs, crhs);
|
||||||
return result;
|
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||||
|
if (has_nsw) result->set_has_no_signed_wrap();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
|
||||||
|
if (has_nuw) result->set_has_no_unsigned_wrap();
|
||||||
|
if (has_nsw) result->set_has_no_signed_wrap();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
#define DEFINE_NOWRAP_BINARY(SUFFIX, OPCODE)\
|
||||||
@@ -161,7 +171,7 @@ value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs
|
|||||||
|
|
||||||
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
|
#define DEFINE_BINARY_INT(SUFFIX, OPCODE)\
|
||||||
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
|
value *builder::create_ ## SUFFIX(value *lhs, value *rhs, const std::string &name){\
|
||||||
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
|
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_UNARY_INT(SUFFIX)\
|
#define DEFINE_UNARY_INT(SUFFIX)\
|
||||||
|
@@ -127,6 +127,42 @@ metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector<u
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// constant expression
|
||||||
|
constant_expression::constant_expression(op_t op, constant_int* lhs, constant_int* rhs)
|
||||||
|
: constant_int(lhs->get_type(), 0),
|
||||||
|
op_(op), lhs_(lhs), rhs_(rhs) { }
|
||||||
|
|
||||||
|
|
||||||
|
constant_expression *constant_expression::create(op_t op, constant_int* lhs, constant_int* rhs) {
|
||||||
|
context_impl *impl = lhs->get_type()->get_context().p_impl.get();
|
||||||
|
constant_expression *& result = impl->expr_constants_[std::make_tuple((int)op, lhs, rhs)];
|
||||||
|
if(!result)
|
||||||
|
result = new constant_expression(op, lhs, rhs);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t constant_expression::get_value() const {
|
||||||
|
uint64_t lhs = lhs_->get_value();
|
||||||
|
uint64_t rhs = rhs_->get_value();
|
||||||
|
switch(op_) {
|
||||||
|
case llop::Add : return lhs + rhs;
|
||||||
|
case llop::Sub : return lhs - rhs;
|
||||||
|
case llop::Mul : return lhs * rhs;
|
||||||
|
case llop::UDiv : return lhs / rhs;
|
||||||
|
case llop::SDiv : return lhs / rhs;
|
||||||
|
case llop::URem : return lhs % rhs;
|
||||||
|
case llop::SRem : return lhs % rhs;
|
||||||
|
case llop::Shl : return lhs << rhs;
|
||||||
|
case llop::LShr : return lhs >> rhs;
|
||||||
|
case llop::AShr : return lhs >> rhs;
|
||||||
|
case llop::And : return lhs && rhs;
|
||||||
|
case llop::Or : return lhs || rhs;
|
||||||
|
case llop::Xor : return lhs ^ rhs;
|
||||||
|
default: throw std::runtime_error("unsupported constexpr binary operator");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// undef value
|
// undef value
|
||||||
undef_value::undef_value(type *ty)
|
undef_value::undef_value(type *ty)
|
||||||
: constant(ty, 0) { }
|
: constant(ty, 0) { }
|
||||||
|
@@ -79,7 +79,8 @@ ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_cons
|
|||||||
ir::type::tile_shapes_t shapes;
|
ir::type::tile_shapes_t shapes;
|
||||||
for(expression *expr: shapes_->values()){
|
for(expression *expr: shapes_->values()){
|
||||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||||
assert(shape);
|
if(shape == nullptr)
|
||||||
|
throw std::runtime_error("tile shapes must be constant expressions");
|
||||||
shapes.push_back(shape);
|
shapes.push_back(shape);
|
||||||
}
|
}
|
||||||
return ir::tile_type::get(type, shapes);
|
return ir::tile_type::get(type, shapes);
|
||||||
|
@@ -101,6 +101,7 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
|
|||||||
ir::value* binary_expression::codegen(ir::module *mod) const{
|
ir::value* binary_expression::codegen(ir::module *mod) const{
|
||||||
ir::value *lhs = lhs_->codegen(mod);
|
ir::value *lhs = lhs_->codegen(mod);
|
||||||
ir::value *rhs = rhs_->codegen(mod);
|
ir::value *rhs = rhs_->codegen(mod);
|
||||||
|
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
|
||||||
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
|
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@@ -169,7 +170,8 @@ ir::value* reshape_expression::codegen(ir::module *mod) const {
|
|||||||
ir::type::tile_shapes_t shapes;
|
ir::type::tile_shapes_t shapes;
|
||||||
for(expression *expr: shapes_->values()){
|
for(expression *expr: shapes_->values()){
|
||||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||||
assert(shape);
|
if(shape == nullptr)
|
||||||
|
throw std::runtime_error("tile shapes must be constant expressions");
|
||||||
shapes.push_back(shape);
|
shapes.push_back(shape);
|
||||||
}
|
}
|
||||||
// return
|
// return
|
||||||
@@ -210,7 +212,6 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
|
|||||||
return mod->get_builder().create_sqrt(arg_->codegen(mod));
|
return mod->get_builder().create_sqrt(arg_->codegen(mod));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||||
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
|
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
|
||||||
|
Reference in New Issue
Block a user