Add argmin argmax (#552)

This commit is contained in:
Keren Zhou
2022-06-15 13:55:20 -07:00
committed by GitHub
parent 6b9756532f
commit b5e728cb14
11 changed files with 345 additions and 101 deletions

View File

@@ -308,13 +308,20 @@ private:
void create(size_t id, const std::vector<ir::value*>& values);
public:
void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
@@ -322,7 +329,19 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution
void run(ir::module &mod);
@@ -336,6 +355,7 @@ private:
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
};
}

View File

@@ -118,8 +118,15 @@ private:
llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
public:
private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
@@ -176,9 +183,8 @@ public:
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);

View File

@@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN, UMAX, UMIN,
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
FADD, FSUB, FMAX, FMIN,
ARGFMAX, ARGFMIN,
XOR
};
@@ -932,12 +934,19 @@ public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
bool with_index() const {
return with_index_ops_.find(op_) != with_index_ops_.end();
}
private:
unsigned axis_;
op_t op_;
const static inline std::set<op_t> with_index_ops_ = {
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);