From 8fdd7e7ed67cc4bcde1dec4ed614f1bd4839ffe7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 13 Sep 2021 17:39:06 -0700 Subject: [PATCH] [LANG] Fixed semantics of boolean load/store (#282) --- lib/ir/dispatch.cc | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 02dbc6ef5..61ce7317e 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -509,11 +509,18 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir: other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); } } + ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); + ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); + // treat bool* as int8* + if(elt_ty == builder->get_int1_ty()){ + elt_ty = builder->get_int8_ty(); + ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); + ptr = dispatch::cast(ptr, ptr_ty, builder); + } if (!mask && !other) return builder->create_load(ptr); if (!mask) throw std::runtime_error("`other` cannot be provided without `mask`"); - ir::type *elt_ty = ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(); auto shape = ptr->get_type()->get_block_shapes(); if(!other){ other = ir::undef_value::get(elt_ty); @@ -530,8 +537,16 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir:: val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); if(mask) mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - ir::type *ptr_ty = ptr->get_type(); - val = dispatch::cast(val, ptr_ty->get_scalar_ty()->get_pointer_element_ty(), builder); + ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); + ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); + // treat bool* as int8* + if(elt_ty == builder->get_int1_ty()){ + elt_ty = builder->get_int8_ty(); + ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); + ptr = dispatch::cast(ptr, ptr_ty, builder); + } + // cast to target data-type + val = dispatch::cast(val, elt_ty, builder); if (!mask) return builder->create_store(ptr, val); if(!mask->get_type()->get_scalar_ty()->is_bool_ty())