Add argmin argmax (#552)
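The commit threads index-returning reductions (argmin/argmax) through the whole stack: new reduce ops in the IR, a (value, index) accumulator in the LLVM codegen, an extra int32 shared-memory scratch layout, barrier analysis for that scratch, and Python-level builtins plus tests. A minimal usage sketch modeled on the tests added below (tensor sizes and names are illustrative, not part of the commit):

import torch
import triton
import triton.language as tl

@triton.jit
def argmax_kernel(X, Z, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.store(Z, tl.argmax(x, axis=0))  # stores an int32 index, not a value

x = torch.randn(1024, device='cuda')
z = torch.empty(1, dtype=torch.int32, device='cuda')
argmax_kernel[(1,)](x, z, BLOCK=1024)
assert x[z.item()] == x.max()  # compare by value: ties admit several valid indices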
@@ -308,13 +308,20 @@ private:
   void create(size_t id, const std::vector<ir::value*>& values);
 
+public:
+  void create_tmp_layout(size_t id, data_layout* arg,
+                         const std::vector<int>& axes,
+                         const std::vector<unsigned>& shape,
+                         ir::instruction* i,
+                         bool is_index = false);
+
 public:
   // constructor
   layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
 
   // accessors
   unsigned layout_of(ir::value *value) const { return groups_.at(value); }
   bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
   bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
   const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
   size_t num_layouts() const { return values_.size();}
   data_layout* get(size_t id) { return layouts_.at(id); }
@@ -322,7 +329,19 @@ public:
   std::map<size_t, data_layout*> &get_all() { return layouts_; }
   bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
   int tmp(ir::value* i) { return tmp_.at(i);}
+  int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
+  int tmp_index(ir::value* i) { return tmp_index_.at(i);}
   void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
 
+  // layout checkers
+  bool is_scanline(ir::instruction* i);
+
+  bool is_coalesced_scanline(ir::instruction* i);
+
+  bool is_mma(ir::instruction* i);
+
+  bool is_a100_mma(ir::instruction* i);
+
   // execution
   void run(ir::module &mod);
 
@@ -336,6 +355,7 @@ private:
   std::map<size_t, std::vector<ir::value*>> values_;
   std::map<size_t, data_layout*> layouts_;
   std::map<ir::value*, size_t> tmp_;
+  std::map<ir::value*, size_t> tmp_index_;
 };
 
 }
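Note on the header changes above: tmp_index_ mirrors tmp_, because an index-carrying reduction needs shared-memory scratch for the running int32 indices in addition to the values. A rough Python analogue of the bookkeeping (dicts stand in for the real shared_layout objects; not the actual API):

tmp_, tmp_index_ = {}, {}   # instruction -> scratch layout id

def create_tmp_layout(layouts, id, inst, scalar_ty, is_index=False):
    # index buffers are always int32; value buffers use the result's scalar type
    layouts[id] = {"ty": "int32" if is_index else scalar_ty, "users": [inst]}
    (tmp_index_ if is_index else tmp_)[inst] = id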
@@ -118,8 +118,15 @@ private:
   llvm::Attribute cvt(ir::attribute attr);
   void packed_type(ir::value* i);
   void forward_declare(ir::function* fn);
+  Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
 
-public:
+private:
+  typedef std::function<void(
+      std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
+      std::function<Value *()> load_index_fn, bool is_first)>
+      acc_fn_t;
+
+public:
   generator(analysis::axes *a_axes,
             analysis::layouts *layouts,
             analysis::align *alignment,
@@ -176,9 +183,8 @@ public:
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
   Value* shfl_sync(Value* acc, int32_t i);
-  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
-  void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
-  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
+  void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
+  void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
   void visit_layout_convert(ir::value *out, ir::value *in);
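The acc_fn_t typedef above replaces the old Value* (Value*, Value*) accumulator: the accumulator becomes a (value, index) pair mutated in place, and both operands arrive as thunks so the index load is only emitted when the reduction actually tracks indices. A simplified Python model of the calling convention (the real callback emits LLVM IR rather than computing values):

def do_acc(acc, load_value_fn, load_index_fn, is_first, with_index=True):
    # acc is a mutable [value, index] pair, standing in for std::pair<Value*, Value*>&
    val = load_value_fn()
    if not with_index:
        acc[0] = val if is_first else max(acc[0], val)
        return
    idx = load_index_fn()  # lazy: plain reductions never evaluate this thunk
    if is_first or val > acc[0]:   # '>' plays the role of the ARGMAX comparator
        acc[0], acc[1] = val, idx

acc = [None, None]
do_acc(acc, lambda: 7, lambda: 0, is_first=True)
do_acc(acc, lambda: 9, lambda: 3, is_first=False)
assert acc == [9, 3]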
@@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst {
 public:
   enum op_t{
     ADD, SUB, MAX, MIN, UMAX, UMIN,
+    ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
     FADD, FSUB, FMAX, FMIN,
+    ARGFMAX, ARGFMIN,
     XOR
   };
 
@@ -932,12 +934,19 @@ public:
   static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
   unsigned get_axis() const { return axis_; }
   op_t get_op() const { return op_; }
+  bool with_index() const {
+    return with_index_ops_.find(op_) != with_index_ops_.end();
+  }
 
 private:
-  unsigned axis_;
-  op_t op_;
+  const static inline std::set<op_t> with_index_ops_ = {
+      op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
+      op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
+  unsigned axis_;
+  op_t op_;
 };
 
 
 class select_inst: public builtin_inst {
 private:
   select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
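The six new enum members split the index reductions by comparison domain, and with_index() is the predicate the rest of the diff keys on. Illustrative grouping (the comparators are the ones visit_reduce_inst selects later in this commit):

ARG_OPS = {
    'signed':   ('ARGMIN', 'ARGMAX'),    # icmp_sle / icmp_sge
    'unsigned': ('ARGUMIN', 'ARGUMAX'),  # icmp_ule / icmp_uge
    'float':    ('ARGFMIN', 'ARGFMAX'),  # fcmp_ole / fcmp_oge
}
WITH_INDEX_OPS = {op for pair in ARG_OPS.values() for op in pair}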
@@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   }
 }
 
+// layout checkers
+bool layouts::is_scanline(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_scanline() != nullptr;
+}
+
+bool layouts::is_coalesced_scanline(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    auto *scanline = this->get(i->get_operand(0))->to_scanline();
+    return scanline && scanline->get_order()[0] == red->get_axis();
+  }
+  return false;
+}
+
+bool layouts::is_mma(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_mma() != nullptr;
+}
+
+bool layouts::is_a100_mma(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
+           (red->get_axis() == 1);
+  }
+  return false;
+}
+
+void layouts::create_tmp_layout(size_t id, data_layout *arg,
+                                const std::vector<int> &axes,
+                                const std::vector<unsigned> &shape,
+                                ir::instruction *i, bool is_index) {
+  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
+                          : i->get_type()->get_scalar_ty();
+  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
+  if (is_index) {
+    tmp_index_[i] = id;
+  } else {
+    tmp_[i] = id;
+  }
+}
+
 void layouts::run(ir::module &mod) {
   // make graph
   graph_.clear();
@@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) {
     // std::cout << "layout: " << std::endl;
     // i->print(std::cout);
     if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
-      id++;
       ir::value *arg = red->get_operand(0);
-      unsigned axis = red->get_axis();
+      distributed_layout *layout =
+          dynamic_cast<analysis::distributed_layout *>(get(arg));
       // shape
       auto shapes = arg->get_type()->get_block_shapes();
-      distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(get(arg));
-      shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
+      unsigned axis = red->get_axis();
+      shapes[axis] =
+          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
       // create layout
-      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[red] = id;
+      id++;
+      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
+
+      if (red->with_index()) {
+        id++;
+        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
+      }
     }
     if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
       distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
       distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
-      id++;
       size_t dim = val->get_type()->get_tile_rank();
       ir::type::block_shapes_t shape(dim);
       for(size_t k = 0; k < dim; k++){
@@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) {
       int out_vec = out_layout->contig_per_thread(out_ord[0]);
       int pad = std::max(in_vec, out_vec);
      shape[out_ord[0]] += pad;
-      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[val] = id;
+      id++;
+      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
     }
     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
       id++;
-      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[atom] = id;
+      create_tmp_layout(id, nullptr, {}, {1}, atom);
    }
   });
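Net effect of the run() changes: each reduction takes one fresh layout id for its value scratch, and an index-carrying reduction takes a second id for the int32 index scratch. A runnable analogue of the id allocation (illustrative only):

def allocate_scratch(next_id, with_index):
    next_id += 1
    ids = {'value': next_id}        # create_tmp_layout(id, ..., red)
    if with_index:
        next_id += 1
        ids['index'] = next_id      # create_tmp_layout(id, ..., red, true)
    return next_id, ids

assert allocate_scratch(6, with_index=True) == (8, {'value': 7, 'index': 8})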
@@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
 #define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
 #define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
+#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
+#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
 #define fmul(...) builder_->CreateFMul(__VA_ARGS__)
 #define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
 #define fsub(...) builder_->CreateFSub(__VA_ARGS__)
@@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
 /**
  * \brief Code Generation for `reduce` (ND case)
  */
-void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
-  //
+void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
   ir::value *arg = x->get_operand(0);
+  const auto with_index = x->with_index();
   unsigned axis = x->get_axis();
   analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
-  std::vector<unsigned> shapes = layout->get_shape();
+  const auto &shapes = layout->get_shape();
 
   Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
   size_t n_bits = sca_ty->getPrimitiveSizeInBits();
 
   std::string n_bits_str = std::to_string(n_bits);
   std::string cst = (n_bits == 64) ? "l" : "r";
@@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
   FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
   InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
 
+  Type *index_ty = IntegerType::get(*ctx_, 32);
+  FunctionType *st_shared_index_ty =
+      FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
+  InlineAsm *st_shared_index = InlineAsm::get(
+      st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
+  FunctionType *ld_shared_index_ty =
+      FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
+  InlineAsm *ld_shared_index = InlineAsm::get(
+      ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
+
   Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
   Value* warp = udiv(thread, i32(32));
@@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
   std::vector<indices_t> arg_idxs = idxs_.at(arg);
   size_t n_elts = arg_idxs.size();
   unsigned col_per_thread = 0;
   Value* warp_i;
-  Value* warp_j;
-  if(analysis::scanline_layout* scanline = layout->to_scanline()){
+  Value* warp_j = nullptr;
+  if (analysis::scanline_layout *scanline = layout->to_scanline()) {
     std::vector<int> order = layout->get_order();
     unsigned mts = scanline->mts(order[0]);
     shuffle_width = std::min<int>(mts, 32);
-    warps_per_inner = std::max<int>(mts/32, 1);
+    warps_per_inner = std::max<int>(mts / 32, 1);
     col_per_thread = shapes[order[0]] / mts;
     warp_i = udiv(warp, i32(warps_per_inner));
     warp_j = urem(warp, i32(warps_per_inner));
-  }
-  else if(layout->to_mma()){
+  } else if (layout->to_mma()) {
     shuffle_width = 4;
     warps_per_inner = layout->to_mma()->wpt(1);
     col_per_thread = 16;
     warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
     warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
   }
+  assert(warp_j != nullptr);
 
   // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
   //
-  Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
-  unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
-  Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
+  Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
+                                       cvt(x->get_type()->get_scalar_ty()));
+  Value *index_base =
+      with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
+                                          IntegerType::get(*ctx_, 32))
+                 : nullptr;
 
   // preds
   Value* is_lane0 = icmp_eq(lane, i32(0));
   Value* is_warp0 = icmp_eq(warp, i32(0));
   Value* is_thread0 = icmp_eq(thread, i32(0));
   Value* lane_j = urem(lane, i32(shuffle_width));
   Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
   add_barrier();
   // compute partial sum for each warp, and store to shared memory
   for(size_t i = 0; i < n_elts/col_per_thread; i++){
-    Value* acc;
+    std::pair<Value*, Value*> acc;
     // reduce within thread
     for(size_t j = 0; j < col_per_thread; j++){
-      Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
-      // acc = (j == 0) ? val : do_acc(acc, val);
-      acc = (j == 0) ? val : do_acc(acc, val);
+      auto arg_idx = arg_idxs[i*col_per_thread + j];
+      bool is_first = j == 0;
+      do_acc(
+          acc, [&]() -> Value * { return arg_vals[arg_idx]; },
+          [&]() -> Value * { return arg_idx[axis]; }, is_first);
     }
     // reduce within warp
-    for(int k = shuffle_width/2 ; k > 0; k >>= 1)
-      acc = do_acc(acc, shfl_sync(acc, k));
+    for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
+      do_acc(
+          acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
+          [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
+    }
     // store partial result to shared memory
     auto x_idxs = idxs_[x][i];
     Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
     Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
-    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
+    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
+    if (with_index) {
+      call(st_shared_index,
+           {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
+    }
   }
   add_barrier();
   // at this point, partial accumulator synchronized in shared memory
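The two do_acc loops above form the usual tree reduction, first over a thread's private columns and then across lanes via shfl_sync, now carrying the index alongside the value. A plain-Python model of the lane phase for argmax (power-of-two width, as the shuffle assumes):

def tree_argmax(pairs):
    # pairs: one (value, index) tuple per lane; length = shuffle width
    k = len(pairs) // 2
    while k > 0:
        for lane in range(k):
            v0, i0 = pairs[lane]
            v1, i1 = pairs[lane + k]          # the shfl_sync(acc, k) partner
            pairs[lane] = (v0, i0) if v0 >= v1 else (v1, i1)
        k //= 2
    return pairs[0]

assert tree_argmax([(3, 0), (9, 1), (7, 2), (9, 3)]) == (9, 1)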
@@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
     auto x_idxs = idxs_[x][i];
     Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
     Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
-    Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
-    for(int k = warps_per_inner/2; k > 0; k >>= 1)
-      acc = do_acc(acc, shfl_sync(acc, k));
-    vals_[x][idxs_[x][i]] = acc;
+    std::pair<Value*, Value*> acc;
+    acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
+    acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
+                                                     gep(index_base, ld_off)})
+                            : nullptr;
+    for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
+      do_acc(
+          acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
+          [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
+    }
+    vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
   }
   // add_barrier();
 }
 
-void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
+void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
   ir::value *arg = x->get_operand(0);
   Type *ty = cvt(x->get_type()->get_scalar_ty());
   unsigned axis = x->get_axis();
+  auto with_index = x->with_index();
 
   // reduce within thread
-  std::map<indices_t, Value*> accs;
+  // index-><current reduced value, current min/max index (optional)>
+  std::map<indices_t, std::pair<Value*, Value*>> accs;
   for(indices_t idx: idxs_.at(arg)){
     indices_t pidx = idx;
     pidx[axis] = i32(0);
-    Value *current = vals_[arg][idx];
     bool is_first = accs.find(pidx) == accs.end();
-    accs[pidx] = is_first ? current : do_acc(accs[pidx], current);
+    do_acc(
+        accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
+        [&]() -> Value * { return idx[axis]; }, is_first);
   };
 
   // reduce within blocks
-  analysis::data_layout* layout = layouts_->get(layouts_->tmp(x));
-  Value *base = shared_ptr_.at(layout);
-  auto shape = layout->get_shape();
-  auto order = layout->get_order();
-  int space = base->getType()->getPointerAddressSpace();
-  Value *ptr = bit_cast(base, ptr_ty(ty, space));
+  auto *data_layout = layouts_->get(layouts_->tmp(x));
+  auto *data_ptr =
+      cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
+  auto *index_ptr =
+      with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
+                                          IntegerType::get(*ctx_, 32))
+                 : data_ptr;
+
+  auto shape = data_layout->get_shape();
+  auto order = data_layout->get_order();
   Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
   for(auto& x: accs) {
     // current element being computed
-    Value *&acc = x.second;
+    std::pair<Value *, Value *> acc = x.second;
     indices_t write_idx = x.first;
     write_idx[axis] = lane;
     // shared memory write pointer
     Value *write_off = shared_off(shape, order, write_idx);
-    Value *write_ptr = gep(ptr, write_off);
+    Value *write_ptr = gep(data_ptr, write_off);
+    Value *index_write_ptr = gep(index_ptr, write_off);
     // initialize shared memory
     add_barrier();
-    store(acc, write_ptr);
+    store(acc.first, write_ptr);
+    if (with_index) {
+      store(acc.second, index_write_ptr);
+    }
     // build result
     indices_t idx(write_idx.size(), i32(0));
     for(size_t i = shape[axis]/2; i > 0; i >>= 1){
@@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
       Value *read_msk = icmp_ult(lane, i32(i));
       Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
       Value *read_ptr = gep(write_ptr, read_off);
+      Value *index_read_ptr = gep(index_write_ptr, read_off);
       add_barrier();
       // update accumulator
-      acc = do_acc(acc, load(read_ptr));
+      do_acc(
+          acc, [&]() -> Value * { return load(read_ptr); },
+          [&]() -> Value * { return load(index_read_ptr); }, false);
       add_barrier();
-      store(acc, write_ptr);
+      store(acc.first, write_ptr);
+      if (with_index) {
+        store(acc.second, index_write_ptr);
+      }
     }
   }
   add_barrier();
@@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
     indices_t read_idx = idx;
     read_idx.insert(read_idx.begin() + axis, i32(0));
     Value *read_off = shared_off(shape, order, read_idx);
-    Value *read_ptr = gep(ptr, read_off);
+    Value *read_ptr =
+        with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
     vals_[x][idx] = load(read_ptr);
   };
 }
@@ -2494,45 +2540,75 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
   Type *ty = cvt(x->get_type()->get_scalar_ty());
   // accumulation function
   ir::reduce_inst::op_t op = x->get_op();
-  auto do_acc = [&](Value *x, Value *y) -> Value* {
+  auto do_acc_op = [&](Value *x, Value *y) -> Value* {
     switch(op){
       case ir::reduce_inst::ADD: return add(x, y);
       case ir::reduce_inst::SUB: return sub(x, y);
-      case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
-      case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
+      case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
+      case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
+      case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
+      case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
       case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
       case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
+      case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
+      case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
       case ir::reduce_inst::FADD: return fadd(x, y);
       case ir::reduce_inst::FSUB: return fsub(x, y);
+      case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
+      case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
      case ir::reduce_inst::FMAX: return max_num(x, y);
       case ir::reduce_inst::FMIN: return min_num(x, y);
       case ir::reduce_inst::XOR: return xor_(x, y);
       default: throw std::runtime_error("unreachable");
     }
   };
 
+  auto do_acc = [&](std::pair<Value *, Value *> &acc,
+                    std::function<Value *()> load_value_fn,
+                    std::function<Value *()> load_index_fn,
+                    bool is_first) -> void {
+    auto *val = load_value_fn();
+    if (x->with_index()) {
+      auto *index = load_index_fn();
+      if (is_first) {
+        acc.first = val;
+        acc.second = index;
+      } else {
+        Value *ret = do_acc_op(acc.first, val);
+        acc.first = select(ret, acc.first, val);
+        acc.second = select(ret, acc.second, index);
+      }
+    } else {
+      acc.first = is_first ? val : do_acc_op(acc.first, val);
+    }
+  };
+
   // neutral element
   Value *neutral;
   switch(op) {
     case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
-    case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
-    case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
     case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
+    case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
     case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
+    case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
+    case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
     case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
    case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
-    case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break;
+    case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
     default: throw std::runtime_error("unreachable");
   }
   ir::value *arg = x->get_operand(0);
-  int cc = tgt_->as_nvidia()->sm();
-  analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
-  analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
-  bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
-  bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
-  if(is_coalesced_scanline || is_a100_mma)
+  bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
+  bool is_a100_mma = layouts_->is_a100_mma(x);
+  if (is_coalesced_scanline || is_a100_mma)
     visit_reducend_inst_fast(x, do_acc, neutral);
   else
     visit_reducend_inst(x, do_acc, neutral);
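The trick in do_acc above: for arg ops, do_acc_op returns the raw i1 comparison (icmp_sge and friends) instead of a selected value, and do_acc reuses that single mask to select both halves of the pair, so value and index can never go out of sync. In Python terms:

def arg_step(acc, cand):
    keep = acc[0] >= cand[0]       # do_acc_op's result for ARGMAX
    return acc if keep else cand   # one mask selects value and index together

assert arg_step((9, 1), (9, 3)) == (9, 1)  # ties keep the incumbent index

Because ties keep whichever element the traversal saw first, the returned index is a valid extremum but not necessarily NumPy's first occurrence; the tests below account for that.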
@@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){
   fns_[fn] = ret;
 }
 
+Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
+                                         Type *ty) {
+  unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
+  Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
+  return base;
+}
+
 void generator::visit_function(ir::function* fn) {
   idxs_.clear();
   vals_.clear();
@@ -60,15 +60,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
       continue;
     analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
     analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
+    analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
     for(ir::value* b: bs){
      if(!b->get_type()->is_block_ty())
         continue;
       analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
       analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
+      analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
       if(intersect_with(a_layout, b_layout) ||
          intersect_with(a_layout, b_tmp) ||
+         intersect_with(a_layout, b_tmp_index) ||
          intersect_with(a_tmp, b_layout) ||
-         intersect_with(a_tmp, b_tmp))
+         intersect_with(a_tmp, b_tmp) ||
+         intersect_with(a_tmp, b_tmp_index) ||
+         intersect_with(a_tmp_index, b_layout) ||
+         intersect_with(a_tmp_index, b_tmp) ||
+         intersect_with(a_tmp_index, b_tmp_index))
         ret.insert(b);
     }
   }
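The membar change simply makes the new index buffers visible to barrier placement: the test now covers the full 3 x 3 product of {layout, tmp, tmp_index} on both sides. An equivalent formulation (illustrative Python):

from itertools import product

def needs_barrier(a_bufs, b_bufs, intersect):
    # a_bufs / b_bufs: [layout, tmp, tmp_index]; entries may be None
    return any(intersect(a, b) for a, b in product(a_bufs, b_bufs))

overlap = lambda a, b: a is not None and b is not None and a == b
assert needs_barrier(['s0', None, 's1'], ['s1', None, None], overlap)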
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
   return builder->create_sqrt(input);
 };
 
-/*----------------------------------------------
-   definition of triton.min
- ----------------------------------------------*/
 ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                        ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
     throw_not_int_or_float(name);
 }
 
+/*----------------------------------------------
+   definition of triton.min
+ ----------------------------------------------*/
 std::string min_docstr = R"pbdoc(
     Returns the minimum value of `input`.
 )pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
   return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
 };
 
+/*----------------------------------------------
+   definition of triton.arg_min
+ ----------------------------------------------*/
+std::string min_docstr = R"pbdoc(
+    Returns the minimum value's index of `input`.
+)pbdoc";
+ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
+};
+
 /*----------------------------------------------
    definition of triton.max
  ----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
   return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
 };
 
+/*----------------------------------------------
+   definition of triton.arg_max
+ ----------------------------------------------*/
+std::string max_docstr = R"pbdoc(
+    Returns the maximum value's index of `input`.
+)pbdoc";
+ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
+};
+
 /*----------------------------------------------
    definition of triton.sum
  ----------------------------------------------*/
@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
       .value("MAX", ir::reduce_inst::MAX)
       .value("UMIN", ir::reduce_inst::UMIN)
       .value("UMAX", ir::reduce_inst::UMAX)
+      .value("ARGMIN", ir::reduce_inst::ARGMIN)
+      .value("ARGMAX", ir::reduce_inst::ARGMAX)
+      .value("ARGUMIN", ir::reduce_inst::ARGUMIN)
+      .value("ARGUMAX", ir::reduce_inst::ARGUMAX)
       .value("FMIN", ir::reduce_inst::FMIN)
       .value("FMAX", ir::reduce_inst::FMAX)
+      .value("ARGFMIN", ir::reduce_inst::ARGFMIN)
+      .value("ARGFMAX", ir::reduce_inst::ARGFMAX)
       .value("XOR", ir::reduce_inst::XOR);
 
   py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
@@ -690,7 +690,7 @@ def test_f16_to_f8_rounding():
 
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
-                          for op in ['min', 'max', 'sum']
+                          for op in ['min', 'max', 'argmin', 'argmax', 'sum']
                           for dtype in dtypes
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -707,28 +707,37 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
     # limit the range of integers so that the sum does not overflow
     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
     x_tri = to_triton(x, device=device)
-    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
     # numpy result
-    z_ref = numpy_op(x).astype(getattr(np, dtype_str))
+    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
+    z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
+    z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
     else:
-        np.testing.assert_equal(z_ref, to_numpy(z_tri))
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            np.testing.assert_equal(x[z_ref], x[z_tri])
+        else:
+            np.testing.assert_equal(z_ref, z_tri)
 
 
 reduce_configs1 = [
     (op, dtype, (1, 1024), axis) for dtype in dtypes
-    for op in ['min', 'max', 'sum']
+    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
 reduce_configs2 = [
-    (op, 'float32', shape, 1)
-    for op in ['min', 'max', 'sum']
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
+    for axis in [0, 1]
 ]
 
 
@@ -741,7 +750,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
         range_n = tl.arange(0, BLOCK_N)
         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
         z = GENERATE_TEST_HERE
-        tl.store(Z + range_m, z)
+        if AXIS == 1:
+            tl.store(Z + range_m, z)
+        else:
+            tl.store(Z + range_n, z)
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
     # input
@@ -749,17 +761,30 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     # limit the range of integers so that the sum does not overflow
     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
     x_tri = to_triton(x)
-    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
     # numpy result
-    z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
+    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
-    binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
+                      device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
     else:
-        np.testing.assert_equal(z_ref, to_numpy(z_tri))
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices.
+            # so instead we compare the values pointed by indices
+            z_ref_index = np.expand_dims(z_ref, axis=axis)
+            z_tri_index = np.expand_dims(z_tri, axis=axis)
+            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
+            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
+            np.testing.assert_equal(z_ref_value, z_tri_value)
+        else:
+            np.testing.assert_equal(z_ref, z_tri)
 
 # ---------------
 # test permute
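The value-based comparison in these tests sidesteps tie-breaking: Triton's reduction order differs from NumPy's left-to-right scan, so with duplicated extrema the two sides may legitimately return different indices. A small demonstration:

import numpy as np

x = np.array([5, 9, 2, 9])
# np.argmax(x) == 1, but index 3 is an equally valid answer;
# comparing the gathered values accepts either.
assert x[np.argmax(x)] == x[3] == 9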
@@ -1000,6 +1000,13 @@ def max(input, axis, _builder=None):
     return semantic.max(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("maximum index")
+def argmax(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmax(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("minimum")
 def min(input, axis, _builder=None):
@@ -1007,6 +1014,13 @@ def min(input, axis, _builder=None):
     return semantic.min(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("minimum index")
+def argmin(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmin(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("sum")
 def sum(input, axis, _builder=None):
@@ -961,10 +961,14 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
 
     # choose the right unsigned operation
     if scalar_ty.is_int_unsigned():
-        if INT_OP is ir.REDUCE_OP.MIN:
-            INT_OP = ir.REDUCE_OP.UMIN
-        elif INT_OP is ir.REDUCE_OP.MAX:
-            INT_OP = ir.REDUCE_OP.UMAX
+        int_op_to_unit = {
+            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
+            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
+            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
+            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
+        }
+        if INT_OP in int_op_to_unit:
+            INT_OP = int_op_to_unit[INT_OP]
 
     # get result type
     shape = input.type.shape
@@ -988,10 +992,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
 
 
+def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
+
+
 def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
 
 
+def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
+
+
 def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
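reduce_impl keeps the frontend dtype-aware: for unsigned integer tensors the signed op is swapped for its unsigned twin before the IR is built, and the table now covers the arg variants too. Worked through by hand (mirrors the mapping above):

int_op_to_unit = {'MIN': 'UMIN', 'MAX': 'UMAX',
                  'ARGMIN': 'ARGUMIN', 'ARGMAX': 'ARGUMAX'}
op = 'ARGMAX'                # e.g. tl.argmax over a uint32 tensor
if op in int_op_to_unit:     # scalar_ty.is_int_unsigned()
    op = int_op_to_unit[op]
assert op == 'ARGUMAX'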