Add argmin argmax (#552)
@@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   }
 }
 
+// layout checkers
+bool layouts::is_scanline(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_scanline() != nullptr;
+}
+
+bool layouts::is_coalesced_scanline(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    auto *scanline = this->get(i->get_operand(0))->to_scanline();
+    return scanline && scanline->get_order()[0] == red->get_axis();
+  }
+  return false;
+}
+
+bool layouts::is_mma(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_mma() != nullptr;
+}
+
+bool layouts::is_a100_mma(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
+           (red->get_axis() == 1);
+  }
+  return false;
+}
+
+void layouts::create_tmp_layout(size_t id, data_layout *arg,
+                                const std::vector<int> &axes,
+                                const std::vector<unsigned> &shape,
+                                ir::instruction *i, bool is_index) {
+  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
+                          : i->get_type()->get_scalar_ty();
+  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
+  if (is_index) {
+    tmp_index_[i] = id;
+  } else {
+    tmp_[i] = id;
+  }
+}
+
 void layouts::run(ir::module &mod) {
   // make graph
   graph_.clear();
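Note: the new create_tmp_layout helper allocates the shared scratch layout for an instruction and records its id in tmp_, or in the new tmp_index_ map with an int32 element type when is_index is true. The second buffer exists because argmin/argmax must reduce values and their indices in tandem. A minimal self-contained sketch of that combine semantics follows; the smallest-index tie-break is an assumption for illustration, not taken from this diff.

#include <cstddef>
#include <utility>
#include <vector>

// Indexed-max combine: keep the larger value; on ties, keep the smaller
// index (tie-break convention assumed here for illustration).
static std::pair<float, int> argmax_combine(std::pair<float, int> a,
                                            std::pair<float, int> b) {
  if (b.first > a.first || (b.first == a.first && b.second < a.second))
    return b;
  return a;
}

// Reference argmax: the value accumulator and the index accumulator move
// together, which is why reductions with with_index() get a second (int32)
// scratch buffer alongside the value buffer.
static std::pair<float, int> argmax(const std::vector<float> &x) {
  std::pair<float, int> acc{x[0], 0};
  for (std::size_t i = 1; i < x.size(); ++i)
    acc = argmax_combine(acc, {x[i], static_cast<int>(i)});
  return acc;
}

int main() {
  std::vector<float> x{1.f, 5.f, 3.f, 5.f};
  auto [v, idx] = argmax(x); // v == 5, idx == 1 (first occurrence wins)
  return (v == 5.f && idx == 1) ? 0 : 1;
}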
@@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) {
 // std::cout << "layout: " << std::endl;
 // i->print(std::cout);
     if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
-      id++;
       ir::value *arg = red->get_operand(0);
-      unsigned axis = red->get_axis();
+      distributed_layout *layout =
+          dynamic_cast<analysis::distributed_layout *>(get(arg));
       // shape
       auto shapes = arg->get_type()->get_block_shapes();
-      distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(get(arg));
-      shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
-
+      unsigned axis = red->get_axis();
+      shapes[axis] =
+          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
       // create layout
-      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[red] = id;
+      id++;
+      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
+
+      if (red->with_index()) {
+        id++;
+        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
+      }
     }
     if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
       distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
       distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
-      id++;
       size_t dim = val->get_type()->get_tile_rank();
       ir::type::block_shapes_t shape(dim);
       for(size_t k = 0; k < dim; k++){
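Note: the scratch-tile extent along the reduced axis is shape_per_cta(axis) / contig_per_thread(axis) because each thread first reduces its own contiguous elements in registers, so shared memory only needs one slot per thread along that axis. A worked example with illustrative numbers (not from this diff):

#include <cassert>

int main() {
  // Illustrative numbers: a 128-element reduced axis where each thread
  // privately holds 4 contiguous elements.
  unsigned shape_per_cta = 128;
  unsigned contig_per_thread = 4;
  // Each thread pre-reduces its 4 elements in registers, so the shared
  // scratch tile needs only 128 / 4 = 32 slots along the reduced axis.
  unsigned scratch_dim = shape_per_cta / contig_per_thread;
  assert(scratch_dim == 32);
  return 0;
}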
@@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) {
       int out_vec = out_layout->contig_per_thread(out_ord[0]);
       int pad = std::max(in_vec, out_vec);
       shape[out_ord[0]] += pad;
-      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[val] = id;
+      id++;
+      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
     }
     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
       id++;
-      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[atom] = id;
+      create_tmp_layout(id, nullptr, {}, {1}, atom);
     }
   });
 
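Note: in the cvt_layout_inst case above, the scratch tile's fastest-varying dimension is padded by max(in_vec, out_vec); padding a shared tile this way staggers rows across shared-memory banks so the vectorized reads and writes of the layout conversion avoid bank conflicts. A worked example with illustrative numbers (not from this diff):

#include <algorithm>
#include <cassert>

int main() {
  // Illustrative numbers: the input layout reads 4 contiguous elements
  // per thread, the output layout writes 2.
  int in_vec = 4, out_vec = 2;
  // Pad by the larger vector width so neither direction of the
  // conversion hits shared-memory bank conflicts.
  int pad = std::max(in_vec, out_vec);
  int fastest = 32;  // extent of shape[out_ord[0]] before padding
  fastest += pad;    // shape[out_ord[0]] += pad;
  assert(fastest == 36);
  return 0;
}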