[LANG] Fixed semantics of boolean load/store (#282)

This commit is contained in:
Philippe Tillet
2021-09-13 17:39:06 -07:00
committed by GitHub
parent 3e395bc84e
commit 8fdd7e7ed6

View File

@@ -509,11 +509,18 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
}
}
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
if (!mask && !other)
return builder->create_load(ptr);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
ir::type *elt_ty = ptr->get_type()->get_scalar_ty()->get_pointer_element_ty();
auto shape = ptr->get_type()->get_block_shapes();
if(!other){
other = ir::undef_value::get(elt_ty);
@@ -530,8 +537,16 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type();
val = dispatch::cast(val, ptr_ty->get_scalar_ty()->get_pointer_element_ty(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
val = dispatch::cast(val, elt_ty, builder);
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())