Add argmin argmax (#552)

This commit is contained in:
Keren Zhou
2022-06-15 13:55:20 -07:00
committed by GitHub
parent 6b9756532f
commit b5e728cb14
11 changed files with 345 additions and 101 deletions

View File

@@ -308,13 +308,20 @@ private:
void create(size_t id, const std::vector<ir::value*>& values);
public:
void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
@@ -322,7 +329,19 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution
void run(ir::module &mod);
@@ -336,6 +355,7 @@ private:
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
};
}

View File

@@ -118,8 +118,15 @@ private:
llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
public:
private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
@@ -176,9 +183,8 @@ public:
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);

View File

@@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN, UMAX, UMIN,
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
FADD, FSUB, FMAX, FMIN,
ARGFMAX, ARGFMIN,
XOR
};
@@ -932,12 +934,19 @@ public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
bool with_index() const {
return with_index_ops_.find(op_) != with_index_ops_.end();
}
private:
unsigned axis_;
op_t op_;
const static inline std::set<op_t> with_index_ops_ = {
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);

View File

@@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
}
}
// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
return this->get(i->get_operand(0))->to_scanline() != nullptr;
}
bool layouts::is_coalesced_scanline(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
auto *scanline = this->get(i->get_operand(0))->to_scanline();
return scanline && scanline->get_order()[0] == red->get_axis();
}
return false;
}
bool layouts::is_mma(ir::instruction *i) {
return this->get(i->get_operand(0))->to_mma() != nullptr;
}
bool layouts::is_a100_mma(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
(red->get_axis() == 1);
}
return false;
}
void layouts::create_tmp_layout(size_t id, data_layout *arg,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
if (is_index) {
tmp_index_[i] = id;
} else {
tmp_[i] = id;
}
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
@@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
distributed_layout *layout =
dynamic_cast<analysis::distributed_layout *>(get(arg));
// shape
auto shapes = arg->get_type()->get_block_shapes();
distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(get(arg));
shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
unsigned axis = red->get_axis();
shapes[axis] =
layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id;
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
if (red->with_index()) {
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
}
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
id++;
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
@@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) {
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id;
id++;
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id;
create_tmp_layout(id, nullptr, {}, {1}, atom);
}
});

View File

@@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
@@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
/**
* \brief Code Generation for `reduce` (ND case)
*/
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
//
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
ir::value *arg = x->get_operand(0);
const auto with_index = x->with_index();
unsigned axis = x->get_axis();
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
std::vector<unsigned> shapes = layout->get_shape();
const auto &shapes = layout->get_shape();
Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
size_t n_bits = sca_ty->getPrimitiveSizeInBits();
std::string n_bits_str = std::to_string(n_bits);
std::string cst = (n_bits == 64) ? "l" : "r";
@@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
Type *index_ty = IntegerType::get(*ctx_, 32);
FunctionType *st_shared_index_ty =
FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
InlineAsm *st_shared_index = InlineAsm::get(
st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
FunctionType *ld_shared_index_ty =
FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
InlineAsm *ld_shared_index = InlineAsm::get(
ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* warp = udiv(thread, i32(32));
@@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
std::vector<indices_t> arg_idxs = idxs_.at(arg);
size_t n_elts = arg_idxs.size();
unsigned col_per_thread = 0;
Value* warp_i;
Value* warp_j;
if(analysis::scanline_layout* scanline = layout->to_scanline()){
Value* warp_j = nullptr;
if (analysis::scanline_layout *scanline = layout->to_scanline()) {
std::vector<int> order = layout->get_order();
unsigned mts = scanline->mts(order[0]);
shuffle_width = std::min<int>(mts, 32);
warps_per_inner = std::max<int>(mts/32, 1);
warps_per_inner = std::max<int>(mts / 32, 1);
col_per_thread = shapes[order[0]] / mts;
warp_i = udiv(warp, i32(warps_per_inner));
warp_j = urem(warp, i32(warps_per_inner));
}
else if(layout->to_mma()){
} else if (layout->to_mma()) {
shuffle_width = 4;
warps_per_inner = layout->to_mma()->wpt(1);
col_per_thread = 16;
warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
}
assert(warp_j != nullptr);
// unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
//
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
cvt(x->get_type()->get_scalar_ty()));
Value *index_base =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: nullptr;
// preds
Value* is_lane0 = icmp_eq(lane, i32(0));
Value* is_warp0 = icmp_eq(warp, i32(0));
Value* is_thread0 = icmp_eq(thread, i32(0));
Value* lane_j = urem(lane, i32(shuffle_width));
Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
add_barrier();
// compute partial sum for each warp, and store to shared memory
for(size_t i = 0; i < n_elts/col_per_thread; i++){
Value* acc;
std::pair<Value*, Value*> acc;
// reduce within thread
for(size_t j = 0; j < col_per_thread; j++){
Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
// acc = (j == 0) ? val : do_acc(acc, val);
acc = (j == 0) ? val : do_acc(acc, val);
auto arg_idx = arg_idxs[i*col_per_thread + j];
bool is_first = j == 0;
do_acc(
acc, [&]() -> Value * { return arg_vals[arg_idx]; },
[&]() -> Value * { return arg_idx[axis]; }, is_first);
}
// reduce within warp
for(int k = shuffle_width/2 ; k > 0; k >>= 1)
acc = do_acc(acc, shfl_sync(acc, k));
for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
// store partial result to shared memory
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
call(st_shared_index,
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
}
}
add_barrier();
// at this point, partial accumulator synchronized in shared memory
@@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
for(int k = warps_per_inner/2; k > 0; k >>= 1)
acc = do_acc(acc, shfl_sync(acc, k));
vals_[x][idxs_[x][i]] = acc;
std::pair<Value*, Value*> acc;
acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
gep(index_base, ld_off)})
: nullptr;
for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
// add_barrier();
}
void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
ir::value *arg = x->get_operand(0);
Type *ty = cvt(x->get_type()->get_scalar_ty());
unsigned axis = x->get_axis();
auto with_index = x->with_index();
// reduce within thread
std::map<indices_t, Value*> accs;
// index-><current reduced value, current min/max index (optional)>
std::map<indices_t, std::pair<Value*, Value*>> accs;
for(indices_t idx: idxs_.at(arg)){
indices_t pidx = idx;
pidx[axis] = i32(0);
Value *current = vals_[arg][idx];
bool is_first = accs.find(pidx) == accs.end();
accs[pidx] = is_first ? current : do_acc(accs[pidx], current);
do_acc(
accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
[&]() -> Value * { return idx[axis]; }, is_first);
};
// reduce within blocks
analysis::data_layout* layout = layouts_->get(layouts_->tmp(x));
Value *base = shared_ptr_.at(layout);
auto shape = layout->get_shape();
auto order = layout->get_order();
int space = base->getType()->getPointerAddressSpace();
Value *ptr = bit_cast(base, ptr_ty(ty, space));
auto *data_layout = layouts_->get(layouts_->tmp(x));
auto *data_ptr =
cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
auto *index_ptr =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: data_ptr;
auto shape = data_layout->get_shape();
auto order = data_layout->get_order();
Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
for(auto& x: accs) {
// current element being computed
Value *&acc = x.second;
std::pair<Value *, Value *> acc = x.second;
indices_t write_idx = x.first;
write_idx[axis] = lane;
// shared memory write pointer
Value *write_off = shared_off(shape, order, write_idx);
Value *write_ptr = gep(ptr, write_off);
Value *write_ptr = gep(data_ptr, write_off);
Value *index_write_ptr = gep(index_ptr, write_off);
// initialize shared memory
add_barrier();
store(acc, write_ptr);
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
// build result
indices_t idx(write_idx.size(), i32(0));
for(size_t i = shape[axis]/2; i > 0; i >>= 1){
@@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
Value *read_msk = icmp_ult(lane, i32(i));
Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
Value *read_ptr = gep(write_ptr, read_off);
Value *index_read_ptr = gep(index_write_ptr, read_off);
add_barrier();
// update accumulator
acc = do_acc(acc, load(read_ptr));
do_acc(
acc, [&]() -> Value * { return load(read_ptr); },
[&]() -> Value * { return load(index_read_ptr); }, false);
add_barrier();
store(acc, write_ptr);
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
}
}
add_barrier();
@@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
indices_t read_idx = idx;
read_idx.insert(read_idx.begin() + axis, i32(0));
Value *read_off = shared_off(shape, order, read_idx);
Value *read_ptr = gep(ptr, read_off);
Value *read_ptr =
with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
vals_[x][idx] = load(read_ptr);
};
}
@@ -2494,45 +2540,75 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
Type *ty = cvt(x->get_type()->get_scalar_ty());
// accumulation function
ir::reduce_inst::op_t op = x->get_op();
auto do_acc = [&](Value *x, Value *y) -> Value* {
auto do_acc_op = [&](Value *x, Value *y) -> Value* {
switch(op){
case ir::reduce_inst::ADD: return add(x, y);
case ir::reduce_inst::SUB: return sub(x, y);
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
case ir::reduce_inst::FADD: return fadd(x, y);
case ir::reduce_inst::FSUB: return fsub(x, y);
case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
case ir::reduce_inst::FMAX: return max_num(x, y);
case ir::reduce_inst::FMIN: return min_num(x, y);
case ir::reduce_inst::XOR: return xor_(x, y);
default: throw std::runtime_error("unreachable");
}
};
auto do_acc = [&](std::pair<Value *, Value *> &acc,
std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn,
bool is_first) -> void {
auto *val = load_value_fn();
if (x->with_index()) {
auto *index = load_index_fn();
if (is_first) {
acc.first = val;
acc.second = index;
} else {
Value *ret = do_acc_op(acc.first, val);
acc.first = select(ret, acc.first, val);
acc.second = select(ret, acc.second, index);
}
} else {
acc.first = is_first ? val : do_acc_op(acc.first, val);
}
};
// neutral element
Value *neutral;
switch(op) {
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
int cc = tgt_->as_nvidia()->sm();
analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
if(is_coalesced_scanline || is_a100_mma)
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
bool is_a100_mma = layouts_->is_a100_mma(x);
if (is_coalesced_scanline || is_a100_mma)
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
@@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){
fns_[fn] = ret;
}
Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
Type *ty) {
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
return base;
}
void generator::visit_function(ir::function* fn) {
idxs_.clear();
vals_.clear();

View File

@@ -60,15 +60,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_layout, b_tmp_index) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp))
intersect_with(a_tmp, b_tmp) ||
intersect_with(a_tmp, b_tmp_index) ||
intersect_with(a_tmp_index, b_layout) ||
intersect_with(a_tmp_index, b_tmp) ||
intersect_with(a_tmp_index, b_tmp_index))
ret.insert(b);
}
}

View File

@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
throw_not_int_or_float(name);
}
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.arg_min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value's index of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.arg_max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value's index of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/

View File

@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
.value("MAX", ir::reduce_inst::MAX)
.value("UMIN", ir::reduce_inst::UMIN)
.value("UMAX", ir::reduce_inst::UMAX)
.value("ARGMIN", ir::reduce_inst::ARGMIN)
.value("ARGMAX", ir::reduce_inst::ARGMAX)
.value("ARGUMIN", ir::reduce_inst::ARGUMIN)
.value("ARGUMAX", ir::reduce_inst::ARGUMAX)
.value("FMIN", ir::reduce_inst::FMIN)
.value("FMAX", ir::reduce_inst::FMAX)
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")

View File

@@ -690,7 +690,7 @@ def test_f16_to_f8_rounding():
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'sum']
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for dtype in dtypes
for shape in [32, 64, 128, 512]])
def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -707,28 +707,37 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
# limit the range of integers so that the sum does not overflow
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x_tri = to_triton(x, device=device)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
# numpy result
z_ref = numpy_op(x).astype(getattr(np, dtype_str))
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
kernel[(1,)](x_tri, z_tri, BLOCK=shape)
z_tri = to_numpy(z_tri)
# compare
if op == 'sum':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
np.testing.assert_equal(z_ref, to_numpy(z_tri))
if op == 'argmin' or op == 'argmax':
# argmin and argmax can have multiple valid indices.
# so instead we compare the values pointed by indices
np.testing.assert_equal(x[z_ref], x[z_tri])
else:
np.testing.assert_equal(z_ref, z_tri)
reduce_configs1 = [
(op, dtype, (1, 1024), axis) for dtype in dtypes
for op in ['min', 'max', 'sum']
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for axis in [1]
]
reduce_configs2 = [
(op, 'float32', shape, 1)
for op in ['min', 'max', 'sum']
(op, 'float32', shape, axis)
for op in ['min', 'max', 'argmin', 'argmax', 'sum']
for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
for axis in [0, 1]
]
@@ -741,7 +750,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
tl.store(Z + range_m, z)
if AXIS == 1:
tl.store(Z + range_m, z)
else:
tl.store(Z + range_n, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
# input
@@ -749,17 +761,30 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
x_tri = to_triton(x)
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
'argmin': np.argmin, 'argmax': np.argmax}[op]
z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
# numpy result
z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
device=device)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)
# compare
if op == 'sum':
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
np.testing.assert_equal(z_ref, to_numpy(z_tri))
if op == 'argmin' or op == 'argmax':
# argmin and argmax can have multiple valid indices.
# so instead we compare the values pointed by indices
z_ref_index = np.expand_dims(z_ref, axis=axis)
z_tri_index = np.expand_dims(z_tri, axis=axis)
z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
np.testing.assert_equal(z_ref_value, z_tri_value)
else:
np.testing.assert_equal(z_ref, z_tri)
# ---------------
# test permute

View File

@@ -1000,6 +1000,13 @@ def max(input, axis, _builder=None):
return semantic.max(input, axis, _builder)
@builtin
@_add_reduction_docstr("maximum index")
def argmax(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmax(input, axis, _builder)
@builtin
@_add_reduction_docstr("minimum")
def min(input, axis, _builder=None):
@@ -1007,6 +1014,13 @@ def min(input, axis, _builder=None):
return semantic.min(input, axis, _builder)
@builtin
@_add_reduction_docstr("minimum index")
def argmin(input, axis, _builder=None):
axis = _constexpr_to_value(axis)
return semantic.argmin(input, axis, _builder)
@builtin
@_add_reduction_docstr("sum")
def sum(input, axis, _builder=None):

View File

@@ -961,10 +961,14 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
# choose the right unsigned operation
if scalar_ty.is_int_unsigned():
if INT_OP is ir.REDUCE_OP.MIN:
INT_OP = ir.REDUCE_OP.UMIN
elif INT_OP is ir.REDUCE_OP.MAX:
INT_OP = ir.REDUCE_OP.UMAX
int_op_to_unit = {
ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
}
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# get result type
shape = input.type.shape
@@ -988,10 +992,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)