[CORE] Fixed several issues that arose in the development of the
torch-blocksparse package: * Now using warp shuffle in reductions when possible * Various bugfixes in layout inference * Added INFINITY, exponential and select * Better error messages for unimplemented constructs
This commit is contained in:
committed by
Philippe Tillet
parent
ac26fbdc1f
commit
3304629de9
@@ -26,8 +26,6 @@ inline bool is_shmem_res(ir::value* v){
|
||||
return false;
|
||||
if(i->get_id() == ir::INST_TRANS)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_REDUCE)
|
||||
return true;
|
||||
if(i->get_id() == ir::INST_COPY_TO_SHARED)
|
||||
return true;
|
||||
return false;
|
||||
@@ -76,8 +74,9 @@ void cts::run(ir::module &mod) {
|
||||
size_t num_op = i->get_num_operands();
|
||||
// copy to shared operands
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(is_shmem_op(i, k))
|
||||
if(is_shmem_op(i, k)){
|
||||
add_copy(i, i->get_operand(k), builder, true);
|
||||
}
|
||||
// copy from shared operands
|
||||
for(size_t k = 0; k < num_op; k++)
|
||||
if(!dynamic_cast<ir::phi_node*>(i) &&
|
||||
|
@@ -83,6 +83,19 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
}
|
||||
}
|
||||
|
||||
bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
|
||||
auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
|
||||
if(cfs) {
|
||||
ir::value *arg = cfs->get_operand(0);
|
||||
ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
|
||||
if(!cts)
|
||||
return false;
|
||||
cfs->replace_all_uses_with(cts->get_operand(0));
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
||||
if(!x)
|
||||
@@ -183,6 +196,7 @@ void peephole::run(ir::module &mod) {
|
||||
continue;
|
||||
bool was_modified = false;
|
||||
was_modified = was_modified || rewrite_mult(i, builder);
|
||||
was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||
|
Reference in New Issue
Block a user