diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h index 5347421bb..531da19c5 100644 --- a/include/triton/lang/code_gen.h +++ b/include/triton/lang/code_gen.h @@ -48,6 +48,7 @@ protected: void set_ret(ir::value* value); ir::value *GenUnaryMinus(ir::value* arg); + ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc); public: Generator(Parser* parser) : parser_(parser) {} diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index 825a130f4..3a7aabf9e 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -43,6 +43,8 @@ constant_int::constant_int(type *ty, uint64_t value) : constant(ty, 0), value_(value){ } constant_int *constant_int::get(type *ty, uint64_t value) { + if (!ty->is_integer_ty()) + throw std::runtime_error("Cannot create constant_int with non integer ty"); context_impl *impl = ty->get_context().p_impl.get(); constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; if(cst == nullptr) diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 2d72e1794..6de85ea25 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -182,6 +182,46 @@ ir::value* Generator::GenUnaryMinus(ir::value* arg) { return bld_->create_sub(_0, arg); } +ir::value* Generator::GenUnaryInc(UnaryOp* expr, bool is_postfix, + bool is_inc) { + Visit(expr->operand_); + ir::value* arg = ret_; + + ir::value *_1 = nullptr; + ir::value *instr = nullptr; + + if (arg->get_type()->is_floating_point_ty()) { + _1 = ir::constant_fp::get(arg->get_type(), 1.0); + if (is_inc) + instr = bld_->create_fadd(arg, _1); + else + instr = bld_->create_fsub(arg, _1); + } else if (arg->get_type()->is_integer_ty()) { + _1 = ir::constant_int::get(arg->get_type(), 1); + if (is_inc) + instr = bld_->create_add(arg, _1); + else + instr = bld_->create_sub(arg, _1); + } else if (arg->get_type()->is_pointer_ty()) { + ir::type *ty = ir::type::get_int64_ty(*ctx_); + _1 = ir::constant_int::get(ty, 1); + if (is_inc) + instr = bld_->create_gep(arg, {_1}); + else { + ir::value *neg_1 = ir::constant_int::get(ty, -1); + instr = bld_->create_gep(arg, {neg_1}); + } + } else + error_not_implemented("data type not supported for unary inc"); + + mod_->set_value(arg->get_name(), instr); + + if (is_postfix) + return arg; + else + return instr; +} + void Generator::VisitUnaryOp(UnaryOp* unary) { // recursion Visit(unary->operand_); @@ -190,10 +230,10 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { ir::type *arg_scal_ty = arg_ty->get_scalar_ty(); // return switch (unary->op_) { - case Token::PREFIX_INC: return error_not_implemented("prefix increment not implemented"); - case Token::PREFIX_DEC: return error_not_implemented("prefix decrement not implemented"); - case Token::POSTFIX_INC: return error_not_implemented("postfix increment not implemented"); - case Token::POSTFIX_DEC: return error_not_implemented("postfix decrement not implemented"); + case Token::PREFIX_INC: return set_ret(GenUnaryInc(unary, false, true)); + case Token::PREFIX_DEC: return set_ret(GenUnaryInc(unary, false, false)); + case Token::POSTFIX_INC: return set_ret(GenUnaryInc(unary, true, true)); + case Token::POSTFIX_DEC: return set_ret(GenUnaryInc(unary, true, false)); case Token::ADDR: return error_not_implemented("unary & not implemented"); case Token::DEREF: return set_ret(bld_->create_load(arg)); case Token::PLUS: return error_not_implemented("unary + not implemented");