diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 3a4877d36..3dd6b0cfa 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -22,7 +22,7 @@ const char src[] = "\ void test(fp32 *A, fp32 *B, fp32 *C, int32 i){\ int32 tile[16, 16] = 0;\ - int32 test[16, 16] = tile + i;\ + fp32 *test[16, 16] = tile + A;\ i = 1;\ A = A + i;\ }\ diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 52d7a4a2e..6137e1022 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -56,7 +56,11 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, ir::type *left_ty = lhs->get_type()->get_scalar_ty(); ir::type *right_ty = rhs->get_type()->get_scalar_ty(); // One operand is pointer - if(left_ty->is_pointer_ty()){ + if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){ + if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty()) + throw std::runtime_error("invalid operands"); + if(right_ty->is_pointer_ty()) + std::swap(lhs, rhs); is_ptr = true; } // One operand is double diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index a46c4f036..627645a09 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -248,7 +248,7 @@ getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std:: type *getelementptr_inst::get_return_type(type *elt_ty, value *ptr, const std::vector &idx_list) { // result pointer type - type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_pointer_address_space()); + type *ptr_ty = pointer_type::get(get_indexed_type(elt_ty, idx_list), ptr->get_type()->get_scalar_ty()->get_pointer_address_space()); // Tile GEP if(ptr->get_type()->is_tile_ty()) return tile_type::get_same_shapes(ptr_ty, ptr->get_type());