[code generation] fixed bug in on-the-fly AST to IR lowering
This commit is contained in:
@@ -90,6 +90,11 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
// Type-driven convenience overload of implicit_broadcast().
//
// Builds an undef value of type `ty` and hands it, together with `arg`, to the
// value/value overload. `arg` is taken by reference, so whatever that overload
// does to its first operand is visible to the caller; the undef value is a
// local and is not returned.
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
  ir::value *type_carrier = ir::undef_value::get(ty);
  implicit_broadcast(mod, arg, type_carrier);
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::type *lhs_ty = lhs->get_type();
|
||||
@@ -320,7 +325,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, value, value);
|
||||
implicit_broadcast(mod, value, ty);
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
@@ -331,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
/* Expression */
|
||||
/*------------------*/
|
||||
/* Binary operator */
|
||||
// Emits the IR instruction corresponding to this binary operator.
//
// Parameters:
//   mod     - module being generated; forwarded to implicit_broadcast().
//   builder - IR builder used to create the instruction.
//   arg     - left-hand operand.
//   rhs     - right-hand operand.
//   name    - name given to the created instruction (see the review note on
//             ADD / pointer arithmetic below).
//
// Before dispatching, both operands go through implicit_cast(), which takes
// them by reference and reports the operand category through the
// is_float / is_ptr / is_int / is_signed flags, and then through
// implicit_broadcast(). The flags together with `op_` select the instruction.
//
// Returns the value created by the builder.
// Throws std::runtime_error("unreachable") when no (op_, category) pair matches.
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
{
  bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
  implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
  implicit_broadcast(mod, arg, rhs);

  // --- Multiplicative operators ---
  if(op_ == MUL && is_float)
    return builder.create_fmul(arg, rhs, name);
  if(op_ == MUL && is_int)
    return builder.create_mul(arg, rhs, name);
  if(op_ == DIV && is_float)
    return builder.create_fdiv(arg, rhs, name);
  if(op_ == DIV && is_int && is_signed)
    return builder.create_sdiv(arg, rhs, name);
  if(op_ == DIV && is_int && !is_signed)
    return builder.create_udiv(arg, rhs, name);
  if(op_ == MOD && is_float)
    return builder.create_frem(arg, rhs, name);
  if(op_ == MOD && is_int && is_signed)
    return builder.create_srem(arg, rhs, name);
  if(op_ == MOD && is_int && !is_signed)
    return builder.create_urem(arg, rhs, name);

  // --- Additive operators ---
  // NOTE(review): integer ADD and both pointer forms do not forward `name`,
  // unlike every other case here -- confirm whether that is intentional.
  if(op_ == ADD && is_float)
    return builder.create_fadd(arg, rhs, name);
  if(op_ == ADD && is_int)
    return builder.create_add(arg, rhs);
  if(op_ == ADD && is_ptr)
    return builder.create_gep(arg, {rhs});
  if(op_ == SUB && is_float)
    return builder.create_fsub(arg, rhs, name);
  if(op_ == SUB && is_int)
    return builder.create_sub(arg, rhs, name);
  // Pointer minus offset is expressed as a GEP with the negated offset.
  if(op_ == SUB && is_ptr)
    return builder.create_gep(arg, {builder.create_neg(rhs)});

  // --- Shifts ---
  // NOTE(review): RIGHT_SHIFT always uses create_ashr and ignores is_signed --
  // confirm this is the intended behavior for unsigned operands.
  if(op_ == LEFT_SHIFT)
    return builder.create_shl(arg, rhs, name);
  if(op_ == RIGHT_SHIFT)
    return builder.create_ashr(arg, rhs, name);

  // --- Relational operators ---
  if(op_ == LT && is_float)
    return builder.create_fcmpOLT(arg, rhs, name);
  if(op_ == LT && is_int && is_signed)
    return builder.create_icmpSLT(arg, rhs, name);
  if(op_ == LT && is_int && !is_signed)
    return builder.create_icmpULT(arg, rhs, name);
  if(op_ == GT && is_float)
    return builder.create_fcmpOGT(arg, rhs, name);
  if(op_ == GT && is_int && is_signed)
    return builder.create_icmpSGT(arg, rhs, name);
  if(op_ == GT && is_int && !is_signed)
    return builder.create_icmpUGT(arg, rhs, name);
  if(op_ == LE && is_float)
    return builder.create_fcmpOLE(arg, rhs, name);
  if(op_ == LE && is_int && is_signed)
    return builder.create_icmpSLE(arg, rhs, name);
  if(op_ == LE && is_int && !is_signed)
    return builder.create_icmpULE(arg, rhs, name);
  if(op_ == GE && is_float)
    return builder.create_fcmpOGE(arg, rhs, name);
  if(op_ == GE && is_int && is_signed)
    return builder.create_icmpSGE(arg, rhs, name);
  if(op_ == GE && is_int && !is_signed)
    return builder.create_icmpUGE(arg, rhs, name);

  // --- Equality operators ---
  if(op_ == EQ && is_float)
    return builder.create_fcmpOEQ(arg, rhs, name);
  if(op_ == EQ && is_int)
    return builder.create_icmpEQ(arg, rhs, name);
  if(op_ == NE && is_float)
    return builder.create_fcmpONE(arg, rhs, name);
  if(op_ == NE && is_int)
    return builder.create_icmpNE(arg, rhs, name);

  // --- AND / XOR / OR; LAND and LOR map to the same create_and / create_or calls ---
  if(op_ == AND)
    return builder.create_and(arg, rhs, name);
  if(op_ == XOR)
    return builder.create_xor(arg, rhs, name);
  if(op_ == OR)
    return builder.create_or(arg, rhs, name);
  if(op_ == LAND)
    return builder.create_and(arg, rhs, name);
  if(op_ == LOR)
    return builder.create_or(arg, rhs, name);

  // No operator / operand-category combination matched.
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
@@ -433,6 +438,12 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
ir::value *A = A_->codegen(mod);
|
||||
ir::value *B = B_->codegen(mod);
|
||||
ir::value *C = C_->codegen(mod);
|
||||
// unsigned M = A->get_type()->get_tile_shapes()[0];
|
||||
// unsigned N = B->get_type()->get_tile_shapes()[1];
|
||||
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
|
||||
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
|
||||
// ir::value *tmp = ir::undef_value::get(tile_ty);
|
||||
// implicit_broadcast(mod, tmp, C);
|
||||
return mod->get_builder().create_matmul(A, B, C);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user