[LANG] Fixed semantics of boolean load/store (#282)
This commit is contained in:
@@ -509,11 +509,18 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
|
|||||||
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
||||||
|
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
||||||
|
// treat bool* as int8*
|
||||||
|
if(elt_ty == builder->get_int1_ty()){
|
||||||
|
elt_ty = builder->get_int8_ty();
|
||||||
|
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
|
||||||
|
ptr = dispatch::cast(ptr, ptr_ty, builder);
|
||||||
|
}
|
||||||
if (!mask && !other)
|
if (!mask && !other)
|
||||||
return builder->create_load(ptr);
|
return builder->create_load(ptr);
|
||||||
if (!mask)
|
if (!mask)
|
||||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||||
ir::type *elt_ty = ptr->get_type()->get_scalar_ty()->get_pointer_element_ty();
|
|
||||||
auto shape = ptr->get_type()->get_block_shapes();
|
auto shape = ptr->get_type()->get_block_shapes();
|
||||||
if(!other){
|
if(!other){
|
||||||
other = ir::undef_value::get(elt_ty);
|
other = ir::undef_value::get(elt_ty);
|
||||||
@@ -530,8 +537,16 @@ ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::
|
|||||||
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
|
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
|
||||||
if(mask)
|
if(mask)
|
||||||
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
|
||||||
ir::type *ptr_ty = ptr->get_type();
|
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
|
||||||
val = dispatch::cast(val, ptr_ty->get_scalar_ty()->get_pointer_element_ty(), builder);
|
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
|
||||||
|
// treat bool* as int8*
|
||||||
|
if(elt_ty == builder->get_int1_ty()){
|
||||||
|
elt_ty = builder->get_int8_ty();
|
||||||
|
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
|
||||||
|
ptr = dispatch::cast(ptr, ptr_ty, builder);
|
||||||
|
}
|
||||||
|
// cast to target data-type
|
||||||
|
val = dispatch::cast(val, elt_ty, builder);
|
||||||
if (!mask)
|
if (!mask)
|
||||||
return builder->create_store(ptr, val);
|
return builder->create_store(ptr, val);
|
||||||
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
|
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
|
||||||
|
Reference in New Issue
Block a user