Add argmin argmax (#552)
This commit is contained in:
@@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
|
||||
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
|
||||
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
|
||||
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
|
||||
#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
|
||||
#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
|
||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||
@@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
|
||||
/**
|
||||
* \brief Code Generation for `reduce` (ND case)
|
||||
*/
|
||||
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
|
||||
//
|
||||
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
|
||||
ir::value *arg = x->get_operand(0);
|
||||
const auto with_index = x->with_index();
|
||||
unsigned axis = x->get_axis();
|
||||
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
|
||||
std::vector<unsigned> shapes = layout->get_shape();
|
||||
const auto &shapes = layout->get_shape();
|
||||
|
||||
Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
|
||||
size_t n_bits = sca_ty->getPrimitiveSizeInBits();
|
||||
|
||||
std::string n_bits_str = std::to_string(n_bits);
|
||||
std::string cst = (n_bits == 64) ? "l" : "r";
|
||||
|
||||
@@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
|
||||
FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
|
||||
InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
|
||||
|
||||
Type *index_ty = IntegerType::get(*ctx_, 32);
|
||||
FunctionType *st_shared_index_ty =
|
||||
FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
|
||||
InlineAsm *st_shared_index = InlineAsm::get(
|
||||
st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
|
||||
FunctionType *ld_shared_index_ty =
|
||||
FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
|
||||
InlineAsm *ld_shared_index = InlineAsm::get(
|
||||
ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
|
||||
|
||||
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* warp = udiv(thread, i32(32));
|
||||
@@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
|
||||
std::vector<indices_t> arg_idxs = idxs_.at(arg);
|
||||
size_t n_elts = arg_idxs.size();
|
||||
unsigned col_per_thread = 0;
|
||||
Value* warp_i;
|
||||
Value* warp_j;
|
||||
if(analysis::scanline_layout* scanline = layout->to_scanline()){
|
||||
Value* warp_j = nullptr;
|
||||
if (analysis::scanline_layout *scanline = layout->to_scanline()) {
|
||||
std::vector<int> order = layout->get_order();
|
||||
unsigned mts = scanline->mts(order[0]);
|
||||
shuffle_width = std::min<int>(mts, 32);
|
||||
warps_per_inner = std::max<int>(mts/32, 1);
|
||||
warps_per_inner = std::max<int>(mts / 32, 1);
|
||||
col_per_thread = shapes[order[0]] / mts;
|
||||
warp_i = udiv(warp, i32(warps_per_inner));
|
||||
warp_j = urem(warp, i32(warps_per_inner));
|
||||
}
|
||||
else if(layout->to_mma()){
|
||||
shuffle_width = 4;
|
||||
} else if (layout->to_mma()) {
|
||||
shuffle_width = 4;
|
||||
warps_per_inner = layout->to_mma()->wpt(1);
|
||||
col_per_thread = 16;
|
||||
warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
|
||||
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
|
||||
}
|
||||
}
|
||||
assert(warp_j != nullptr);
|
||||
|
||||
// unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
|
||||
//
|
||||
Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
|
||||
Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
|
||||
cvt(x->get_type()->get_scalar_ty()));
|
||||
Value *index_base =
|
||||
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
|
||||
IntegerType::get(*ctx_, 32))
|
||||
: nullptr;
|
||||
|
||||
// preds
|
||||
Value* is_lane0 = icmp_eq(lane, i32(0));
|
||||
Value* is_warp0 = icmp_eq(warp, i32(0));
|
||||
Value* is_thread0 = icmp_eq(thread, i32(0));
|
||||
Value* lane_j = urem(lane, i32(shuffle_width));
|
||||
Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
|
||||
add_barrier();
|
||||
// compute partial sum for each warp, and store to shared memory
|
||||
for(size_t i = 0; i < n_elts/col_per_thread; i++){
|
||||
Value* acc;
|
||||
std::pair<Value*, Value*> acc;
|
||||
// reduce within thread
|
||||
for(size_t j = 0; j < col_per_thread; j++){
|
||||
Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
|
||||
// acc = (j == 0) ? val : do_acc(acc, val);
|
||||
acc = (j == 0) ? val : do_acc(acc, val);
|
||||
auto arg_idx = arg_idxs[i*col_per_thread + j];
|
||||
bool is_first = j == 0;
|
||||
do_acc(
|
||||
acc, [&]() -> Value * { return arg_vals[arg_idx]; },
|
||||
[&]() -> Value * { return arg_idx[axis]; }, is_first);
|
||||
}
|
||||
|
||||
// reduce within warp
|
||||
for(int k = shuffle_width/2 ; k > 0; k >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, k));
|
||||
for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
|
||||
do_acc(
|
||||
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
|
||||
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
|
||||
}
|
||||
// store partial result to shared memory
|
||||
auto x_idxs = idxs_[x][i];
|
||||
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
|
||||
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
|
||||
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
|
||||
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
|
||||
if (with_index) {
|
||||
call(st_shared_index,
|
||||
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
|
||||
}
|
||||
}
|
||||
add_barrier();
|
||||
// at this point, partial accumulator synchronized in shared memory
|
||||
@@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
|
||||
auto x_idxs = idxs_[x][i];
|
||||
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
|
||||
Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
|
||||
Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
|
||||
for(int k = warps_per_inner/2; k > 0; k >>= 1)
|
||||
acc = do_acc(acc, shfl_sync(acc, k));
|
||||
vals_[x][idxs_[x][i]] = acc;
|
||||
std::pair<Value*, Value*> acc;
|
||||
acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
|
||||
acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
|
||||
gep(index_base, ld_off)})
|
||||
: nullptr;
|
||||
for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
|
||||
do_acc(
|
||||
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
|
||||
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
|
||||
}
|
||||
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
|
||||
}
|
||||
// add_barrier();
|
||||
}
|
||||
|
||||
void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
|
||||
|
||||
void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
|
||||
ir::value *arg = x->get_operand(0);
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
unsigned axis = x->get_axis();
|
||||
auto with_index = x->with_index();
|
||||
|
||||
// reduce within thread
|
||||
std::map<indices_t, Value*> accs;
|
||||
// index-><current reduced value, current min/max index (optional)>
|
||||
std::map<indices_t, std::pair<Value*, Value*>> accs;
|
||||
for(indices_t idx: idxs_.at(arg)){
|
||||
indices_t pidx = idx;
|
||||
pidx[axis] = i32(0);
|
||||
Value *current = vals_[arg][idx];
|
||||
bool is_first = accs.find(pidx) == accs.end();
|
||||
accs[pidx] = is_first ? current : do_acc(accs[pidx], current);
|
||||
do_acc(
|
||||
accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
|
||||
[&]() -> Value * { return idx[axis]; }, is_first);
|
||||
};
|
||||
|
||||
// reduce within blocks
|
||||
analysis::data_layout* layout = layouts_->get(layouts_->tmp(x));
|
||||
Value *base = shared_ptr_.at(layout);
|
||||
auto shape = layout->get_shape();
|
||||
auto order = layout->get_order();
|
||||
int space = base->getType()->getPointerAddressSpace();
|
||||
Value *ptr = bit_cast(base, ptr_ty(ty, space));
|
||||
auto *data_layout = layouts_->get(layouts_->tmp(x));
|
||||
auto *data_ptr =
|
||||
cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
|
||||
auto *index_ptr =
|
||||
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
|
||||
IntegerType::get(*ctx_, 32))
|
||||
: data_ptr;
|
||||
|
||||
auto shape = data_layout->get_shape();
|
||||
auto order = data_layout->get_order();
|
||||
Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
|
||||
for(auto& x: accs) {
|
||||
// current element being computed
|
||||
Value *&acc = x.second;
|
||||
std::pair<Value *, Value *> acc = x.second;
|
||||
indices_t write_idx = x.first;
|
||||
write_idx[axis] = lane;
|
||||
// shared memory write pointer
|
||||
Value *write_off = shared_off(shape, order, write_idx);
|
||||
Value *write_ptr = gep(ptr, write_off);
|
||||
Value *write_ptr = gep(data_ptr, write_off);
|
||||
Value *index_write_ptr = gep(index_ptr, write_off);
|
||||
// initialize shared memory
|
||||
add_barrier();
|
||||
store(acc, write_ptr);
|
||||
store(acc.first, write_ptr);
|
||||
if (with_index) {
|
||||
store(acc.second, index_write_ptr);
|
||||
}
|
||||
// build result
|
||||
indices_t idx(write_idx.size(), i32(0));
|
||||
for(size_t i = shape[axis]/2; i > 0; i >>= 1){
|
||||
@@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
Value *read_msk = icmp_ult(lane, i32(i));
|
||||
Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
|
||||
Value *read_ptr = gep(write_ptr, read_off);
|
||||
Value *index_read_ptr = gep(index_write_ptr, read_off);
|
||||
add_barrier();
|
||||
// update accumulator
|
||||
acc = do_acc(acc, load(read_ptr));
|
||||
do_acc(
|
||||
acc, [&]() -> Value * { return load(read_ptr); },
|
||||
[&]() -> Value * { return load(index_read_ptr); }, false);
|
||||
add_barrier();
|
||||
store(acc, write_ptr);
|
||||
store(acc.first, write_ptr);
|
||||
if (with_index) {
|
||||
store(acc.second, index_write_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
add_barrier();
|
||||
@@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
|
||||
indices_t read_idx = idx;
|
||||
read_idx.insert(read_idx.begin() + axis, i32(0));
|
||||
Value *read_off = shared_off(shape, order, read_idx);
|
||||
Value *read_ptr = gep(ptr, read_off);
|
||||
Value *read_ptr =
|
||||
with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
|
||||
vals_[x][idx] = load(read_ptr);
|
||||
};
|
||||
}
|
||||
@@ -2494,45 +2540,75 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
// accumulation function
|
||||
ir::reduce_inst::op_t op = x->get_op();
|
||||
auto do_acc = [&](Value *x, Value *y) -> Value* {
|
||||
auto do_acc_op = [&](Value *x, Value *y) -> Value* {
|
||||
switch(op){
|
||||
case ir::reduce_inst::ADD: return add(x, y);
|
||||
case ir::reduce_inst::SUB: return sub(x, y);
|
||||
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
|
||||
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
|
||||
case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
|
||||
case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
|
||||
case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
|
||||
case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
|
||||
case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
|
||||
case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
|
||||
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
|
||||
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
|
||||
case ir::reduce_inst::FADD: return fadd(x, y);
|
||||
case ir::reduce_inst::FSUB: return fsub(x, y);
|
||||
case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
|
||||
case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
|
||||
case ir::reduce_inst::FMAX: return max_num(x, y);
|
||||
case ir::reduce_inst::FMIN: return min_num(x, y);
|
||||
case ir::reduce_inst::XOR: return xor_(x, y);
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
};
|
||||
|
||||
auto do_acc = [&](std::pair<Value *, Value *> &acc,
|
||||
std::function<Value *()> load_value_fn,
|
||||
std::function<Value *()> load_index_fn,
|
||||
bool is_first) -> void {
|
||||
auto *val = load_value_fn();
|
||||
if (x->with_index()) {
|
||||
auto *index = load_index_fn();
|
||||
if (is_first) {
|
||||
acc.first = val;
|
||||
acc.second = index;
|
||||
} else {
|
||||
Value *ret = do_acc_op(acc.first, val);
|
||||
acc.first = select(ret, acc.first, val);
|
||||
acc.second = select(ret, acc.second, index);
|
||||
}
|
||||
} else {
|
||||
acc.first = is_first ? val : do_acc_op(acc.first, val);
|
||||
}
|
||||
};
|
||||
|
||||
// neutral element
|
||||
Value *neutral;
|
||||
switch(op) {
|
||||
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
|
||||
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
|
||||
case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
|
||||
case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
|
||||
case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
|
||||
case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
|
||||
case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
|
||||
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
|
||||
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
|
||||
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
|
||||
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
|
||||
case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
|
||||
case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
|
||||
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
|
||||
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
|
||||
case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break;
|
||||
case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
|
||||
default: throw std::runtime_error("unreachable");
|
||||
}
|
||||
ir::value *arg = x->get_operand(0);
|
||||
int cc = tgt_->as_nvidia()->sm();
|
||||
analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
|
||||
analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
|
||||
bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
|
||||
bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
|
||||
if(is_coalesced_scanline || is_a100_mma)
|
||||
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
|
||||
bool is_a100_mma = layouts_->is_a100_mma(x);
|
||||
if (is_coalesced_scanline || is_a100_mma)
|
||||
visit_reducend_inst_fast(x, do_acc, neutral);
|
||||
else
|
||||
visit_reducend_inst(x, do_acc, neutral);
|
||||
@@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){
|
||||
fns_[fn] = ret;
|
||||
}
|
||||
|
||||
Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
|
||||
Type *ty) {
|
||||
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
|
||||
Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
|
||||
return base;
|
||||
}
|
||||
|
||||
void generator::visit_function(ir::function* fn) {
|
||||
idxs_.clear();
|
||||
vals_.clear();
|
||||
|
Reference in New Issue
Block a user