[code generation] fixed bug in on-the-fly AST to IR lowering

This commit is contained in:
Philippe Tillet
2019-01-23 00:11:42 -05:00
parent a0ecdba5a2
commit 7eebdceb6a
10 changed files with 344 additions and 102 deletions

View File

@@ -90,6 +90,11 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
throw std::runtime_error("unreachable");
}
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
ir::value *tmp = ir::undef_value::get(ty);
implicit_broadcast(mod, arg, tmp);
}
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
ir::builder &builder = mod->get_builder();
ir::type *lhs_ty = lhs->get_type();
@@ -320,7 +325,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
if(expr_){
value = expr_->codegen(mod);
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
implicit_broadcast(mod, value, value);
implicit_broadcast(mod, value, ty);
}
value->set_name(name);
mod->set_value(name, value);
@@ -331,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
/* Expression */
/*------------------*/
/* Binary operator */
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
{
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, lhs, rhs);
implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, arg, rhs);
if(op_==MUL && is_float)
return builder.create_fmul(lhs, rhs, name);
return builder.create_fmul(arg, rhs, name);
if(op_==MUL && is_int)
return builder.create_mul(lhs, rhs, name);
return builder.create_mul(arg, rhs, name);
if(op_==DIV && is_float)
return builder.create_fdiv(lhs, rhs, name);
return builder.create_fdiv(arg, rhs, name);
if(op_==DIV && is_int && is_signed)
return builder.create_sdiv(lhs, rhs, name);
return builder.create_sdiv(arg, rhs, name);
if(op_==DIV && is_int && !is_signed)
return builder.create_udiv(lhs, rhs, name);
return builder.create_udiv(arg, rhs, name);
if(op_==MOD && is_float)
return builder.create_frem(lhs, rhs, name);
return builder.create_frem(arg, rhs, name);
if(op_==MOD && is_int && is_signed)
return builder.create_srem(lhs, rhs, name);
return builder.create_srem(arg, rhs, name);
if(op_==MOD && is_int && !is_signed)
return builder.create_urem(lhs, rhs, name);
return builder.create_urem(arg, rhs, name);
if(op_==ADD && is_float)
return builder.create_fadd(lhs, rhs, name);
return builder.create_fadd(arg, rhs, name);
if(op_==ADD && is_int)
return builder.create_add(lhs, rhs);
return builder.create_add(arg, rhs);
if(op_==ADD && is_ptr)
return builder.create_gep(lhs, {rhs});
return builder.create_gep(arg, {rhs});
if(op_==SUB && is_float)
return builder.create_fsub(lhs, rhs, name);
return builder.create_fsub(arg, rhs, name);
if(op_==SUB && is_int)
return builder.create_sub(lhs, rhs, name);
return builder.create_sub(arg, rhs, name);
if(op_==SUB && is_ptr)
return builder.create_gep(lhs, {builder.create_neg(rhs)});
return builder.create_gep(arg, {builder.create_neg(rhs)});
if(op_==LEFT_SHIFT)
return builder.create_shl(lhs, rhs, name);
return builder.create_shl(arg, rhs, name);
if(op_==RIGHT_SHIFT)
return builder.create_ashr(lhs, rhs, name);
return builder.create_ashr(arg, rhs, name);
if(op_ == LT && is_float)
return builder.create_fcmpOLT(lhs, rhs, name);
return builder.create_fcmpOLT(arg, rhs, name);
if(op_ == LT && is_int && is_signed)
return builder.create_icmpSLT(lhs, rhs, name);
return builder.create_icmpSLT(arg, rhs, name);
if(op_ == LT && is_int && !is_signed)
return builder.create_icmpULT(lhs, rhs, name);
return builder.create_icmpULT(arg, rhs, name);
if(op_ == GT && is_float)
return builder.create_fcmpOGT(lhs, rhs, name);
return builder.create_fcmpOGT(arg, rhs, name);
if(op_ == GT && is_int && is_signed)
return builder.create_icmpSGT(lhs, rhs, name);
return builder.create_icmpSGT(arg, rhs, name);
if(op_ == GT && is_int && !is_signed)
return builder.create_icmpUGT(lhs, rhs, name);
return builder.create_icmpUGT(arg, rhs, name);
if(op_ == LE && is_float)
return builder.create_fcmpOLE(lhs, rhs, name);
return builder.create_fcmpOLE(arg, rhs, name);
if(op_ == LE && is_int && is_signed)
return builder.create_icmpSLE(lhs, rhs, name);
return builder.create_icmpSLE(arg, rhs, name);
if(op_ == LE && is_int && !is_signed)
return builder.create_icmpULE(lhs, rhs, name);
return builder.create_icmpULE(arg, rhs, name);
if(op_ == GE && is_float)
return builder.create_fcmpOGE(lhs, rhs, name);
return builder.create_fcmpOGE(arg, rhs, name);
if(op_ == GE && is_int && is_signed)
return builder.create_icmpSGE(lhs, rhs, name);
return builder.create_icmpSGE(arg, rhs, name);
if(op_ == GE && is_int && !is_signed)
return builder.create_icmpUGE(lhs, rhs, name);
return builder.create_icmpUGE(arg, rhs, name);
if(op_ == EQ && is_float)
return builder.create_fcmpOEQ(lhs, rhs, name);
return builder.create_fcmpOEQ(arg, rhs, name);
if(op_ == EQ && is_int)
return builder.create_icmpEQ(lhs, rhs, name);
return builder.create_icmpEQ(arg, rhs, name);
if(op_ == NE && is_float)
return builder.create_fcmpONE(lhs, rhs, name);
return builder.create_fcmpONE(arg, rhs, name);
if(op_ == NE && is_int)
return builder.create_icmpNE(lhs, rhs, name);
return builder.create_icmpNE(arg, rhs, name);
if(op_ == AND)
return builder.create_and(lhs, rhs, name);
return builder.create_and(arg, rhs, name);
if(op_ == XOR)
return builder.create_xor(lhs, rhs, name);
return builder.create_xor(arg, rhs, name);
if(op_ == OR)
return builder.create_or(lhs, rhs, name);
return builder.create_or(arg, rhs, name);
if(op_ == LAND)
return builder.create_and(lhs, rhs, name);
return builder.create_and(arg, rhs, name);
if(op_ == LOR)
return builder.create_or(lhs, rhs, name);
return builder.create_or(arg, rhs, name);
throw std::runtime_error("unreachable");
}
@@ -433,6 +438,12 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod);
ir::value *B = B_->codegen(mod);
ir::value *C = C_->codegen(mod);
// unsigned M = A->get_type()->get_tile_shapes()[0];
// unsigned N = B->get_type()->get_tile_shapes()[1];
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
// ir::value *tmp = ir::undef_value::get(tile_ty);
// implicit_broadcast(mod, tmp, C);
return mod->get_builder().create_matmul(A, B, C);
}