Add argmin argmax (#552)
This commit is contained in:
@@ -308,13 +308,20 @@ private:
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
void create_tmp_layout(size_t id, data_layout* arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
ir::instruction* i,
|
||||
bool is_index = false);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
|
||||
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
@@ -322,7 +329,19 @@ public:
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
|
||||
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
|
||||
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
|
||||
|
||||
// layout checkers
|
||||
bool is_scanline(ir::instruction* i);
|
||||
|
||||
bool is_coalesced_scanline(ir::instruction* i);
|
||||
|
||||
bool is_mma(ir::instruction* i);
|
||||
|
||||
bool is_a100_mma(ir::instruction* i);
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
@@ -336,6 +355,7 @@ private:
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, data_layout*> layouts_;
|
||||
std::map<ir::value*, size_t> tmp_;
|
||||
std::map<ir::value*, size_t> tmp_index_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -118,8 +118,15 @@ private:
|
||||
llvm::Attribute cvt(ir::attribute attr);
|
||||
void packed_type(ir::value* i);
|
||||
void forward_declare(ir::function* fn);
|
||||
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
|
||||
|
||||
public:
|
||||
private:
|
||||
typedef std::function<void(
|
||||
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
|
||||
std::function<Value *()> load_index_fn, bool is_first)>
|
||||
acc_fn_t;
|
||||
|
||||
public:
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
@@ -176,9 +183,8 @@ public:
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
|
||||
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_layout_convert(ir::value *out, ir::value *in);
|
||||
|
@@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst {
|
||||
public:
|
||||
enum op_t{
|
||||
ADD, SUB, MAX, MIN, UMAX, UMIN,
|
||||
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
|
||||
FADD, FSUB, FMAX, FMIN,
|
||||
ARGFMAX, ARGFMIN,
|
||||
XOR
|
||||
};
|
||||
|
||||
@@ -932,12 +934,19 @@ public:
|
||||
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
op_t get_op() const { return op_; }
|
||||
bool with_index() const {
|
||||
return with_index_ops_.find(op_) != with_index_ops_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
const static inline std::set<op_t> with_index_ops_ = {
|
||||
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
|
||||
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
|
Reference in New Issue
Block a user