[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-10 15:52:03 -04:00
parent a3f76b6eb1
commit 4efd0a3c6b
9 changed files with 148 additions and 144 deletions

View File

@@ -89,8 +89,8 @@ bool liveness::do_pad(ir::value *x) {
ir::value *b = dot->get_operand(1);
size_t a_previous = pad_[a];
size_t b_previous = pad_[b];
auto a_order = layouts_->get(a).order;
auto b_order = layouts_->get(b).order;
auto a_order = layouts_->get(a)->order;
auto b_order = layouts_->get(b)->order;
bool a_row = is_trans(a) ^ (a_order[0] == 1);
bool b_row = is_trans(b) ^ (b_order[0] == 1);
auto a_shapes = a->get_type()->get_tile_shapes();
@@ -108,9 +108,9 @@ bool liveness::do_pad(ir::value *x) {
}
// padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = layouts_->get(cts).order;
auto cts_order = layouts_->get(cts)->order;
ir::value *arg = cts->get_operand(0);
auto arg_order = layouts_->get(arg).order;
auto arg_order = layouts_->get(arg)->order;
size_t previous = pad_[cts];
if(cts_order != arg_order)
pad_[cts] = std::max<int>(pad_[cts], 4);
@@ -134,26 +134,10 @@ bool liveness::do_pad(ir::value *x) {
}
unsigned liveness::num_bytes(ir::value *x) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis();
ir::value *op = red->get_operand(0);
auto shapes = op->get_type()->get_tile_shapes();
shapes.erase(shapes.begin() + axis);
size_t num_elements = 1;
for(auto x: shapes)
num_elements *= x;
size_t depth;
if(layouts_->get(x).type == HMMA_884)
depth = layouts_->get(op).wpt.at(axis);
else
depth = layouts_->get(op).mts.at(axis);
return num_elements * num_bytes * depth;
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x).order[0]];
unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x)->order[0]];
num_bytes += pad * num_bytes / ld;
}
if(has_double(x))