diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index 050ac6956..99481f694 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -308,13 +308,20 @@ private:
   void create(size_t id, const std::vector<ir::value*>& values);

-public:
+
+  void create_tmp_layout(size_t id, data_layout* arg,
+                         const std::vector<int>& axes,
+                         const std::vector<unsigned>& shape,
+                         ir::instruction* i,
+                         bool is_index = false);
+
+ public:
   // constructor
   layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);

   // accessors
   unsigned layout_of(ir::value *value)  const { return groups_.at(value); }
   bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
+  bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
   const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
   size_t num_layouts() const { return values_.size();}
   data_layout* get(size_t id) { return layouts_.at(id); }
@@ -322,7 +329,19 @@ public:
   std::map<size_t, data_layout*> &get_all() { return layouts_; }
   bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
   int tmp(ir::value* i) { return tmp_.at(i);}
+  int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
+  int tmp_index(ir::value* i) { return tmp_index_.at(i);}
   void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
+
+  // layout checkers
+  bool is_scanline(ir::instruction* i);
+
+  bool is_coalesced_scanline(ir::instruction* i);
+
+  bool is_mma(ir::instruction* i);
+
+  bool is_a100_mma(ir::instruction* i);
+
   // execution
   void run(ir::module &mod);

@@ -336,6 +355,7 @@ private:
   std::map<size_t, std::vector<ir::value*>> values_;
   std::map<size_t, data_layout*> layouts_;
   std::map<ir::value*, size_t> tmp_;
+  std::map<ir::value*, size_t> tmp_index_;
 };

 }
diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 945b9b074..b408a46ca 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -118,8 +118,15 @@ private:
   llvm::Attribute cvt(ir::attribute attr);
   void packed_type(ir::value* i);
   void forward_declare(ir::function* fn);
+  Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);

-public:
+ private:
+  typedef std::function<void(std::pair<Value*, Value*> &acc,
+                             std::function<Value*()> load_value_fn,
+                             std::function<Value*()> load_index_fn, bool is_first)>
+      acc_fn_t;
+
+ public:
   generator(analysis::axes *a_axes,
             analysis::layouts *layouts,
             analysis::align *alignment,
@@ -176,9 +183,8 @@ public:
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
   Value* shfl_sync(Value* acc, int32_t i);
-  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
-  void visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral);
-  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
+  void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
+  void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
   void visit_layout_convert(ir::value *out, ir::value *in);
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index ee7897e03..734ea2b42 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -914,7 +914,9 @@ class reduce_inst: public builtin_inst {
 public:
   enum op_t{
     ADD, SUB, MAX, MIN, UMAX, UMIN,
+    ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
     FADD, FSUB, FMAX, FMIN,
+    ARGFMAX, ARGFMIN,
     XOR
   };
@@ -932,12 +934,19 @@
 public:
   static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
   unsigned get_axis() const { return axis_; }
   op_t get_op() const { return op_; }
+  bool with_index() const {
+    return with_index_ops_.find(op_) != with_index_ops_.end();
+  }

 private:
-  unsigned axis_;
-  op_t op_;
+  const static inline std::set<op_t> with_index_ops_ = {
+      op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
+      op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
+  unsigned axis_;
+  op_t op_;
 };
+
 class select_inst: public builtin_inst {
 private:
   select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 86473dc54..a19be19ef 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   }
 }

+// layout checkers
+bool layouts::is_scanline(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_scanline() != nullptr;
+}
+
+bool layouts::is_coalesced_scanline(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    auto *scanline = this->get(i->get_operand(0))->to_scanline();
+    return scanline && scanline->get_order()[0] == red->get_axis();
+  }
+  return false;
+}
+
+bool layouts::is_mma(ir::instruction *i) {
+  return this->get(i->get_operand(0))->to_mma() != nullptr;
+}
+
+bool layouts::is_a100_mma(ir::instruction *i) {
+  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
+    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
+           (red->get_axis() == 1);
+  }
+  return false;
+}
+
+void layouts::create_tmp_layout(size_t id, data_layout *arg,
+                                const std::vector<int> &axes,
+                                const std::vector<unsigned> &shape,
+                                ir::instruction *i, bool is_index) {
+  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
+                          : i->get_type()->get_scalar_ty();
+  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
+  if (is_index) {
+    tmp_index_[i] = id;
+  } else {
+    tmp_[i] = id;
+  }
+}
+
 void layouts::run(ir::module &mod) {
   // make graph
   graph_.clear();
@@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) {
     // std::cout << "layout: " << std::endl;
     // i->print(std::cout);
     if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
-      id++;
       ir::value *arg = red->get_operand(0);
-      unsigned axis = red->get_axis();
+      distributed_layout *layout =
+          dynamic_cast<distributed_layout *>(get(arg));
       // shape
       auto shapes = arg->get_type()->get_block_shapes();
-      distributed_layout* layout = dynamic_cast<distributed_layout*>(get(arg));
-      shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
-
+      unsigned axis = red->get_axis();
+      shapes[axis] =
+          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
       // create layout
-      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[red] = id;
+      id++;
+      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
+
+      if (red->with_index()) {
+        id++;
+        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
+      }
     }
     if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
       distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
       distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
-      id++;
       size_t dim = val->get_type()->get_tile_rank();
       ir::type::block_shapes_t shape(dim);
       for(size_t k = 0; k < dim; k++){
@@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) {
       int out_vec = out_layout->contig_per_thread(out_ord[0]);
       int pad = std::max(in_vec, out_vec);
       shape[out_ord[0]] += pad;
-      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[val] = id;
+      id++;
+      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
     }
     if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
       id++;
-      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
-      tmp_[atom] = id;
+      create_tmp_layout(id, nullptr, {}, {1}, atom);
     }
   });
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 53cfb70fc..ebd21732b 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
 #define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
 #define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
+#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
+#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
 #define fmul(...) builder_->CreateFMul(__VA_ARGS__)
 #define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
 #define fsub(...) builder_->CreateFSub(__VA_ARGS__)
@@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
 /**
  * \brief Code Generation for `reduce` (ND case)
  */
-void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
-  //
+void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
   ir::value *arg = x->get_operand(0);
+  const auto with_index = x->with_index();
+  unsigned axis = x->get_axis();
   analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
-  std::vector<unsigned> shapes = layout->get_shape();
+  const auto &shapes = layout->get_shape();
   Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
   size_t n_bits = sca_ty->getPrimitiveSizeInBits();
-  std::string n_bits_str = std::to_string(n_bits);
   std::string cst = (n_bits == 64) ? "l" : "r";
@@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){

-  InlineAsm *st_shared = InlineAsm::get(FunctionType::get(void_ty, {builder_->getInt1Ty(), ptr_ty(sca_ty, 3), sca_ty}, false),
-                                        "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true);
-  InlineAsm *ld_shared = InlineAsm::get(FunctionType::get(sca_ty, {builder_->getInt1Ty(), ptr_ty(sca_ty, 3)}, false),
-                                        "@$0 ld.shared.b" + n_bits_str + " $1, [$2];", "=" + cst + ",b," + cst, true);
+  InlineAsm *st_shared = InlineAsm::get(FunctionType::get(void_ty, {builder_->getInt1Ty(), ptr_ty(sca_ty, 3), sca_ty}, false),
+                                        "@$0 st.shared.b" + std::to_string(n_bits) + " [$1], $2;", "b," + cst + "," + cst, true);
+  InlineAsm *ld_shared = InlineAsm::get(FunctionType::get(sca_ty, {builder_->getInt1Ty(), ptr_ty(sca_ty, 3)}, false),
+                                        "@$0 ld.shared.b" + std::to_string(n_bits) + " $1, [$2];", "=" + cst + ",b," + cst, true);
+  Type *index_ty = IntegerType::get(*ctx_, 32);
+  InlineAsm *st_shared_index = InlineAsm::get(FunctionType::get(void_ty, {builder_->getInt1Ty(), ptr_ty(index_ty, 3), index_ty}, false),
+                                              "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
+  InlineAsm *ld_shared_index = InlineAsm::get(FunctionType::get(index_ty, {builder_->getInt1Ty(), ptr_ty(index_ty, 3)}, false),
+                                              "@$0 ld.shared.b32 $1, [$2];", "=r,b,r", true);
+
   Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
   Value* warp = udiv(thread, i32(32));
@@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
   unsigned shuffle_width = 0;
   unsigned warps_per_inner = 0;
   auto arg_vals = vals_.at(arg);
   std::vector<indices_t> arg_idxs = idxs_.at(arg);
   size_t n_elts = arg_idxs.size();
   unsigned col_per_thread = 0;
-  Value* warp_i;
-  Value* warp_j;
-  if(analysis::scanline_layout* scanline = layout->to_scanline()){
+  Value* warp_j = nullptr;
+  if (analysis::scanline_layout *scanline = layout->to_scanline()) {
     std::vector<int> order = layout->get_order();
     unsigned mts = scanline->mts(order[0]);
     shuffle_width = std::min<int>(mts, 32);
-    warps_per_inner = std::max(mts/32, 1);
+    warps_per_inner = std::max<int>(mts / 32, 1);
     col_per_thread = shapes[order[0]] / mts;
-    warp_i = udiv(warp, i32(warps_per_inner));
     warp_j = urem(warp, i32(warps_per_inner));
-  }
-  else if(layout->to_mma()){
-    shuffle_width = 4;
+  } else if (layout->to_mma()) {
+    shuffle_width = 4;
     warps_per_inner = layout->to_mma()->wpt(1);
     col_per_thread = 16;
-    warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
     warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
-  }
+  }
+  assert(warp_j != nullptr);
   // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
   //
-  Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
-  unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
-  Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
+  Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
+                                       cvt(x->get_type()->get_scalar_ty()));
+  Value *index_base =
+      with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
+                                          IntegerType::get(*ctx_, 32))
+                 : nullptr;
+
   // preds
   Value* is_lane0 = icmp_eq(lane, i32(0));
   Value* is_warp0 = icmp_eq(warp, i32(0));
   Value* is_thread0 = icmp_eq(thread, i32(0));
   Value* lane_j = urem(lane, i32(shuffle_width));
-  Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
   add_barrier();
   // compute partial sum for each warp, and store to shared memory
   for(size_t i = 0; i < n_elts/col_per_thread; i++){
-    Value* acc;
+    std::pair<Value*, Value*> acc;
     // reduce within thread
     for(size_t j = 0; j < col_per_thread; j++){
-      Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
-      // acc = (j == 0) ? val : do_acc(acc, val);
-      acc = (j == 0) ? val : do_acc(acc, val);
+      auto arg_idx = arg_idxs[i*col_per_thread + j];
+      bool is_first = j == 0;
+      do_acc(
+          acc, [&]() -> Value * { return arg_vals[arg_idx]; },
+          [&]() -> Value * { return arg_idx[axis]; }, is_first);
     }
+
     // reduce within warp
-    for(int k = shuffle_width/2 ; k > 0; k >>= 1)
-      acc = do_acc(acc, shfl_sync(acc, k));
+    for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
+      do_acc(
+          acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
+          [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
+    }
     // store partial result to shared memory
     auto x_idxs = idxs_[x][i];
     Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
     Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
-    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
+    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
+    if (with_index) {
+      call(st_shared_index,
+           {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
+    }
   }
   add_barrier();
   // at this point, partial accumulator synchronized in shared memory
@@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
   // reduce partial result within each warp
   for(size_t i = 0; i < n_elts/col_per_thread; i++){
     auto x_idxs = idxs_[x][i];
     Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
     Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
-    Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
-    for(int k = warps_per_inner/2; k > 0; k >>= 1)
-      acc = do_acc(acc, shfl_sync(acc, k));
-    vals_[x][idxs_[x][i]] = acc;
+    std::pair<Value*, Value*> acc;
+    acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
+    acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
+                                                     gep(index_base, ld_off)})
+                            : nullptr;
+    for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
+      do_acc(
+          acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
+          [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
+    }
+    vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
   }
   // add_barrier();
 }

-void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
+
+void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
   ir::value *arg = x->get_operand(0);
-  Type *ty = cvt(x->get_type()->get_scalar_ty());
   unsigned axis = x->get_axis();
+  auto with_index = x->with_index();

   // reduce within thread
-  std::map<indices_t, Value*> accs;
+  // index-> <current value, current index>
+  std::map<indices_t, std::pair<Value*, Value*>> accs;
   for(indices_t idx: idxs_.at(arg)){
     indices_t pidx = idx;
     pidx[axis] = i32(0);
-    Value *current = vals_[arg][idx];
     bool is_first = accs.find(pidx) == accs.end();
-    accs[pidx] = is_first ? current : do_acc(accs[pidx], current);
+    do_acc(
+        accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
+        [&]() -> Value * { return idx[axis]; }, is_first);
   };

   // reduce within blocks
-  analysis::data_layout* layout = layouts_->get(layouts_->tmp(x));
-  Value *base = shared_ptr_.at(layout);
-  auto shape = layout->get_shape();
-  auto order = layout->get_order();
-  int space = base->getType()->getPointerAddressSpace();
-  Value *ptr = bit_cast(base, ptr_ty(ty, space));
+  auto *data_layout = layouts_->get(layouts_->tmp(x));
+  auto *data_ptr =
+      cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
+  auto *index_ptr =
+      with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
+                                          IntegerType::get(*ctx_, 32))
+                 : data_ptr;
+
+  auto shape = data_layout->get_shape();
+  auto order = data_layout->get_order();
   Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
   for(auto& x: accs) {
     // current element being computed
-    Value *&acc = x.second;
+    std::pair<Value*, Value*> acc = x.second;
     indices_t write_idx = x.first;
     write_idx[axis] = lane;
     // shared memory write pointer
     Value *write_off = shared_off(shape, order, write_idx);
-    Value *write_ptr = gep(ptr, write_off);
+    Value *write_ptr = gep(data_ptr, write_off);
+    Value *index_write_ptr = gep(index_ptr, write_off);
     // initialize shared memory
     add_barrier();
-    store(acc, write_ptr);
+    store(acc.first, write_ptr);
+    if (with_index) {
+      store(acc.second, index_write_ptr);
+    }
     // build result
     indices_t idx(write_idx.size(), i32(0));
     for(size_t i = shape[axis]/2; i > 0; i >>= 1){
@@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
       idx[axis] = i32(i);
       // read pointer
       Value *read_msk = icmp_ult(lane, i32(i));
       Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
       Value *read_ptr = gep(write_ptr, read_off);
+      Value *index_read_ptr = gep(index_write_ptr, read_off);
       add_barrier();
       // update accumulator
-      acc = do_acc(acc, load(read_ptr));
+      do_acc(
+          acc, [&]() -> Value * { return load(read_ptr); },
+          [&]() -> Value * { return load(index_read_ptr); }, false);
       add_barrier();
-      store(acc, write_ptr);
+      store(acc.first, write_ptr);
+      if (with_index) {
+        store(acc.second, index_write_ptr);
+      }
     }
   }
   add_barrier();
@@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
 }

+
 void generator::visit_reduce_inst(ir::reduce_inst* x) {
   Type *ty = cvt(x->get_type()->get_scalar_ty());
   // accumulation function
   ir::reduce_inst::op_t op = x->get_op();
-  auto do_acc = [&](Value *x, Value *y) -> Value* {
+  auto do_acc_op = [&](Value *x, Value *y) -> Value* {
     switch(op){
       case ir::reduce_inst::ADD: return add(x, y);
       case ir::reduce_inst::SUB: return sub(x, y);
-      case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
-      case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
+      case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
+      case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
+      case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
+      case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
       case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
       case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
+      case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
+      case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
       case ir::reduce_inst::FADD: return fadd(x, y);
       case ir::reduce_inst::FSUB: return fsub(x, y);
+      case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
+      case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
       case ir::reduce_inst::FMAX: return max_num(x, y);
       case ir::reduce_inst::FMIN: return min_num(x, y);
       case ir::reduce_inst::XOR: return xor_(x, y);
       default: throw std::runtime_error("unreachable");
     }
   };
+
+  auto do_acc = [&](std::pair<Value*, Value*> &acc,
+                    std::function<Value*()> load_value_fn,
+                    std::function<Value*()> load_index_fn,
+                    bool is_first) -> void {
+    auto *val = load_value_fn();
+    if (x->with_index()) {
+      auto *index = load_index_fn();
+      if (is_first) {
+        acc.first = val;
+        acc.second = index;
+      } else {
+        Value *ret = do_acc_op(acc.first, val);
+        acc.first = select(ret, acc.first, val);
+        acc.second = select(ret, acc.second, index);
+      }
+    } else {
+      acc.first = is_first ? val : do_acc_op(acc.first, val);
+    }
+  };
+
   // neutral element
   Value *neutral;
   switch(op) {
     case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
-    case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
-    case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
+    case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
     case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
     case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
+    case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
+    case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
     case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
     case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
+    case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
+    case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
     case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
     case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
-    case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break;
+    case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
     default: throw std::runtime_error("unreachable");
   }
   ir::value *arg = x->get_operand(0);
-  int cc = tgt_->as_nvidia()->sm();
-  analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
-  analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
-  bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
-  bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
-  if(is_coalesced_scanline || is_a100_mma)
+  bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
+  bool is_a100_mma = layouts_->is_a100_mma(x);
+  if (is_coalesced_scanline || is_a100_mma)
     visit_reducend_inst_fast(x, do_acc, neutral);
   else
     visit_reducend_inst(x, do_acc, neutral);
@@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){
   fns_[fn] = ret;
 }

+Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
+                                         Type *ty) {
+  unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
+  Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
+  return base;
+}
+
 void generator::visit_function(ir::function* fn) {
   idxs_.clear();
   vals_.clear();
diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc
index 96249bcd5..22fe00fe6 100644
--- a/lib/codegen/transform/membar.cc
+++ b/lib/codegen/transform/membar.cc
@@ -60,15 +60,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
       continue;
     analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
     analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
+    analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ?
+        layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
     for(ir::value* b: bs){
       if(!b->get_type()->is_block_ty())
         continue;
       analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
       analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
+      analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
       if(intersect_with(a_layout, b_layout) ||
          intersect_with(a_layout, b_tmp) ||
+         intersect_with(a_layout, b_tmp_index) ||
          intersect_with(a_tmp, b_layout) ||
-         intersect_with(a_tmp, b_tmp))
+         intersect_with(a_tmp, b_tmp) ||
+         intersect_with(a_tmp, b_tmp_index) ||
+         intersect_with(a_tmp_index, b_layout) ||
+         intersect_with(a_tmp_index, b_tmp) ||
+         intersect_with(a_tmp_index, b_tmp_index))
         ret.insert(b);
     }
   }
diff --git a/python/src/functions.h b/python/src/functions.h
index 19f7e7eb9..d5b6c15ef 100644
--- a/python/src/functions.h
+++ b/python/src/functions.h
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
   return builder->create_sqrt(input);
 };

-/*----------------------------------------------
-   definition of triton.min
- ----------------------------------------------*/
 ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                        ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
   throw_not_int_or_float(name);
 }

+/*----------------------------------------------
+   definition of triton.min
+ ----------------------------------------------*/
 std::string min_docstr = R"pbdoc(
     Returns the minimum value of `input`.
   )pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
   return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
 };

+/*----------------------------------------------
+   definition of triton.argmin
+ ----------------------------------------------*/
+std::string argmin_docstr = R"pbdoc(
+    Returns the minimum value's index of `input`.
+  )pbdoc";
+ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
+};
+
 /*----------------------------------------------
    definition of triton.max
  ----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
   return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
 };

+/*----------------------------------------------
+   definition of triton.argmax
+ ----------------------------------------------*/
+std::string argmax_docstr = R"pbdoc(
+    Returns the maximum value's index of `input`.
+  )pbdoc";
+ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
+};
+
 /*----------------------------------------------
    definition of triton.sum
  ----------------------------------------------*/
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 7ebd6b9b9..4e1849733 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -573,8 +573,14 @@ void init_triton_ir(py::module &&m) {
       .value("MAX", ir::reduce_inst::MAX)
       .value("UMIN", ir::reduce_inst::UMIN)
       .value("UMAX", ir::reduce_inst::UMAX)
+      .value("ARGMIN", ir::reduce_inst::ARGMIN)
+      .value("ARGMAX", ir::reduce_inst::ARGMAX)
+      .value("ARGUMIN", ir::reduce_inst::ARGUMIN)
+      .value("ARGUMAX", ir::reduce_inst::ARGUMAX)
       .value("FMIN", ir::reduce_inst::FMIN)
       .value("FMAX", ir::reduce_inst::FMAX)
+      .value("ARGFMIN", ir::reduce_inst::ARGFMIN)
+      .value("ARGFMAX", ir::reduce_inst::ARGFMAX)
       .value("XOR", ir::reduce_inst::XOR);

   py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 348672822..f1b4f899f 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -690,7 +690,7 @@ def test_f16_to_f8_rounding():
 @pytest.mark.parametrize("op, dtype_str, shape",
                          [(op, dtype, shape)
-                          for op in ['min', 'max', 'sum']
+                          for op in ['min', 'max', 'argmin', 'argmax', 'sum']
                           for dtype in dtypes
                           for shape in [32, 64, 128, 512]])
 def test_reduce1d(op, dtype_str, shape, device='cuda'):
@@ -707,28 +707,37 @@ def test_reduce1d(op, dtype_str, shape, device='cuda'):
     # limit the range of integers so that the sum does not overflow
     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
     x_tri = to_triton(x, device=device)
-    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
     # numpy result
-    z_ref = numpy_op(x).astype(getattr(np, dtype_str))
+    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
+    z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device)
+    z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), device=device)
     kernel[(1,)](x_tri, z_tri, BLOCK=shape)
+    z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
     else:
-        np.testing.assert_equal(z_ref, to_numpy(z_tri))
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices,
+            # so we compare the values the indices point to instead
+            np.testing.assert_equal(x[z_ref], x[z_tri])
+        else:
+            np.testing.assert_equal(z_ref, z_tri)


 reduce_configs1 = [
     (op, dtype, (1, 1024), axis) for dtype in dtypes
-    for op in ['min', 'max', 'sum']
+    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for axis in [1]
 ]
 reduce_configs2 = [
-    (op, 'float32', shape, 1)
-    for op in ['min', 'max', 'sum']
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'argmin', 'argmax', 'sum']
     for shape in [(2, 32), (4, 32), (4, 128), (32, 64), (64, 128), (128, 256), (32, 1024)]
+    for axis in [0, 1]
 ]

@@ -741,7 +750,10 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
         range_n = tl.arange(0, BLOCK_N)
         x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
         z = GENERATE_TEST_HERE
-        tl.store(Z + range_m, z)
+        if AXIS == 1:
+            tl.store(Z + range_m, z)
+        else:
+            tl.store(Z + range_n, z)

     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
     # input
@@ -749,17 +761,30 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
     # limit the range of integers so that the sum does not overflow
     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
     x_tri = to_triton(x)
-    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
+    numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min,
+                'argmin': np.argmin, 'argmax': np.argmax}[op]
+    z_dtype_str = 'int32' if op == 'argmin' or op == 'argmax' else dtype_str
     # numpy result
-    z_ref = numpy_op(x, axis=axis).astype(getattr(np, dtype_str))
+    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device)
-    binary = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
+                      device=device)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+        np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
     else:
-        np.testing.assert_equal(z_ref, to_numpy(z_tri))
+        if op == 'argmin' or op == 'argmax':
+            # argmin and argmax can have multiple valid indices,
+            # so we compare the values the indices point to instead
+            z_ref_index = np.expand_dims(z_ref, axis=axis)
+            z_tri_index = np.expand_dims(z_tri, axis=axis)
+            z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis)
+            z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis)
+            np.testing.assert_equal(z_ref_value, z_tri_value)
+        else:
+            np.testing.assert_equal(z_ref, z_tri)

 # ---------------
 # test permute
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index fa6f190e3..d775abf40 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1000,6 +1000,13 @@ def max(input, axis, _builder=None):
     return semantic.max(input, axis, _builder)


+@builtin
+@_add_reduction_docstr("maximum index")
+def argmax(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmax(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("minimum")
 def min(input, axis, _builder=None):
@@ -1007,6 +1014,13 @@ def min(input, axis, _builder=None):
     return semantic.min(input, axis, _builder)


+@builtin
+@_add_reduction_docstr("minimum index")
+def argmin(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    return semantic.argmin(input, axis, _builder)
+
+
 @builtin
 @_add_reduction_docstr("sum")
 def sum(input, axis, _builder=None):
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index e57faa5ec..15a5cb648 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -961,10 +961,14 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,

     # choose the right unsigned operation
     if scalar_ty.is_int_unsigned():
-        if INT_OP is ir.REDUCE_OP.MIN:
-            INT_OP = ir.REDUCE_OP.UMIN
-        elif INT_OP is ir.REDUCE_OP.MAX:
-            INT_OP = ir.REDUCE_OP.UMAX
+        int_op_to_uint = {
+            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
+            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
+            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
+            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
+        }
+        if INT_OP in int_op_to_uint:
+            INT_OP = int_op_to_uint[INT_OP]

     # get result type
     shape = input.type.shape
@@ -988,10 +992,18 @@ def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)


+def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
+
+
 def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)


+def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
+
+
 def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
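Usage sketch (not part of the patch): a minimal kernel exercising the new `tl.argmax` builtin added above, modeled on the updated `test_reduce1d`; the kernel name, `BLOCK`, and the torch-based harness are illustrative assumptions, not code from this diff.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def argmax_kernel(X, Z, BLOCK: tl.constexpr):
        # load a 1D block of values and reduce it to the index of its maximum
        x = tl.load(X + tl.arange(0, BLOCK))
        # tl.argmax returns an int32 index, matching np.argmax in the tests
        tl.store(Z, tl.argmax(x, axis=0))

    x = torch.randn(128, device='cuda')
    z = torch.empty(1, dtype=torch.int32, device='cuda')
    argmax_kernel[(1,)](x, z, BLOCK=128)
    # ties make the index ambiguous, so compare the value it points to,
    # exactly as the updated tests do
    assert x[z.item()] == x.max()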