Add argmin argmax (#552)

Keren Zhou
2022-06-15 13:55:20 -07:00
committed by GitHub
parent 6b9756532f
commit b5e728cb14
11 changed files with 345 additions and 101 deletions
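Summary: this change adds argmin/argmax support to the reduction codegen. New reduce ops (signed ARGMAX/ARGMIN, unsigned ARGUMAX/ARGUMIN, float ARGFMAX/ARGFMIN) thread a (value, index) pair through every reduction stage: the comparison predicate selects both the winning value and its index. Index-tracking reductions allocate a second i32 scratch buffer in shared memory (registered in tmp_index_), and the membar pass is extended so barrier placement accounts for it.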

View File

@@ -588,6 +588,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
}
}
// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
return this->get(i->get_operand(0))->to_scanline() != nullptr;
}
bool layouts::is_coalesced_scanline(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
auto *scanline = this->get(i->get_operand(0))->to_scanline();
return scanline && scanline->get_order()[0] == red->get_axis();
}
return false;
}
bool layouts::is_mma(ir::instruction *i) {
return this->get(i->get_operand(0))->to_mma() != nullptr;
}
bool layouts::is_a100_mma(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
(red->get_axis() == 1);
}
return false;
}
void layouts::create_tmp_layout(size_t id, data_layout *arg,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_);
if (is_index) {
tmp_index_[i] = id;
} else {
tmp_[i] = id;
}
}
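create_tmp_layout centralizes what the reduce, cvt_layout, and atomic cases below previously did inline: allocate a shared scratch layout for an instruction and register it. For index-tracking reductions it runs twice, once for values and once with is_index = true, which switches both the element type (indices are always i32) and the registry map. A minimal sketch of that bookkeeping, with stand-in types in place of the real IR and layout classes:

#include <cstddef>
#include <map>

struct instruction;     // stand-in for ir::instruction
struct shared_layout_t; // stand-in for analysis::shared_layout

struct tmp_registry {
  std::map<size_t, shared_layout_t *> layouts;  // layouts_
  std::map<instruction *, size_t> tmp;          // value scratch per inst
  std::map<instruction *, size_t> tmp_index;    // index scratch per inst

  void create_tmp_layout(size_t id, shared_layout_t *layout, instruction *i,
                         bool is_index) {
    layouts[id] = layout;
    (is_index ? tmp_index : tmp)[i] = id;  // same id space, two registries
  }
  bool has_tmp_index(instruction *i) const { return tmp_index.count(i) > 0; }
};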
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
@@ -612,22 +651,26 @@ void layouts::run(ir::module &mod) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
ir::value *arg = red->get_operand(0);
distributed_layout *layout =
    dynamic_cast<analysis::distributed_layout *>(get(arg));
// shape
auto shapes = arg->get_type()->get_block_shapes();
unsigned axis = red->get_axis();
shapes[axis] =
    layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
// create layout
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
if (red->with_index()) {
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
}
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
@@ -640,13 +683,12 @@ void layouts::run(ir::module &mod) {
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
id++;
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
create_tmp_layout(id, nullptr, {}, {1}, atom);
}
});

View File

@@ -112,6 +112,8 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fcmp_oge(...) builder_->CreateFCmpOGE(__VA_ARGS__)
#define fcmp_ole(...) builder_->CreateFCmpOLE(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
@@ -2334,15 +2336,15 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
/**
* \brief Code Generation for `reduce` (ND case)
*/
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
ir::value *arg = x->get_operand(0);
const auto with_index = x->with_index();
unsigned axis = x->get_axis();
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
const auto &shapes = layout->get_shape();
Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
size_t n_bits = sca_ty->getPrimitiveSizeInBits();
std::string n_bits_str = std::to_string(n_bits);
std::string cst = (n_bits == 64) ? "l" : "r";
@@ -2351,6 +2353,15 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
Type *index_ty = IntegerType::get(*ctx_, 32);
FunctionType *st_shared_index_ty =
FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false);
InlineAsm *st_shared_index = InlineAsm::get(
st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true);
FunctionType *ld_shared_index_ty =
FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false);
InlineAsm *ld_shared_index = InlineAsm::get(
ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true);
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value* warp = udiv(thread, i32(32));
@@ -2362,54 +2373,64 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
std::vector<indices_t> arg_idxs = idxs_.at(arg);
size_t n_elts = arg_idxs.size();
unsigned col_per_thread = 0;
Value* warp_j = nullptr;
if (analysis::scanline_layout *scanline = layout->to_scanline()) {
std::vector<int> order = layout->get_order();
unsigned mts = scanline->mts(order[0]);
shuffle_width = std::min<int>(mts, 32);
warps_per_inner = std::max<int>(mts / 32, 1);
col_per_thread = shapes[order[0]] / mts;
warp_j = urem(warp, i32(warps_per_inner));
} else if (layout->to_mma()) {
shuffle_width = 4;
warps_per_inner = layout->to_mma()->wpt(1);
col_per_thread = 16;
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
}
assert(warp_j != nullptr);
// unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
//
Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)),
cvt(x->get_type()->get_scalar_ty()));
Value *index_base =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: nullptr;
// preds
Value* is_lane0 = icmp_eq(lane, i32(0));
Value* is_warp0 = icmp_eq(warp, i32(0));
Value* is_thread0 = icmp_eq(thread, i32(0));
Value* lane_j = urem(lane, i32(shuffle_width));
Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
add_barrier();
// compute partial sum for each warp, and store to shared memory
for(size_t i = 0; i < n_elts/col_per_thread; i++){
std::pair<Value*, Value*> acc;
// reduce within thread
for(size_t j = 0; j < col_per_thread; j++){
auto arg_idx = arg_idxs[i*col_per_thread + j];
bool is_first = j == 0;
do_acc(
acc, [&]() -> Value * { return arg_vals[arg_idx]; },
[&]() -> Value * { return arg_idx[axis]; }, is_first);
}
// reduce within warp
for(int k = shuffle_width/2 ; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
// store partial result to shared memory
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
call(st_shared_index,
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
}
}
add_barrier();
// at this point, partial accumulator synchronized in shared memory
@@ -2418,48 +2439,66 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
std::pair<Value*, Value*> acc;
acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true),
gep(index_base, ld_off)})
: nullptr;
for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
do_acc(
acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
[&]() -> Value * { return shfl_sync(acc.second, k); }, false);
}
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
// add_barrier();
}
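The fast path above reduces in three stages: within a thread over col_per_thread elements, across a warp via shfl_sync, and across warps through the shared scratch buffer; with with_index set, every stage combines (value, index) pairs. Below is a self-contained host-side sketch of the warp stage, assuming a butterfly partner (lane ^ k) as shfl.sync.bfly would provide. Note that ties keep whichever pair is already in the accumulator, so the reported index depends on combine order rather than being, say, the smallest matching index:

#include <array>
#include <cstdint>
#include <iostream>
#include <utility>

int main() {
  constexpr int kLanes = 8; // stand-in for shuffle_width
  // One (value, index) accumulator per lane, as after the in-thread stage.
  std::array<std::pair<int32_t, int32_t>, kLanes> acc = {
      {{3, 0}, {7, 1}, {2, 2}, {9, 3}, {5, 4}, {9, 5}, {1, 6}, {0, 7}}};
  for (int k = kLanes / 2; k > 0; k >>= 1) {
    auto next = acc;
    for (int l = 0; l < kLanes; ++l) {
      auto other = acc[l ^ k];            // what shfl_sync(acc.*, k) returns
      if (!(acc[l].first >= other.first)) // ARGMAX: icmp_sge keeps acc on ties
        next[l] = other;                  // select() swaps value and index together
    }
    acc = next;
  }
  // The max 9 occurs at indices 3 and 5; combine order yields "9 at 5".
  std::cout << acc[0].first << " at " << acc[0].second << "\n";
}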
void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
ir::value *arg = x->get_operand(0);
Type *ty = cvt(x->get_type()->get_scalar_ty());
unsigned axis = x->get_axis();
auto with_index = x->with_index();
// reduce within thread
// index-><current reduced value, current min/max index (optional)>
std::map<indices_t, std::pair<Value*, Value*>> accs;
for(indices_t idx: idxs_.at(arg)){
indices_t pidx = idx;
pidx[axis] = i32(0);
bool is_first = accs.find(pidx) == accs.end();
do_acc(
accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
[&]() -> Value * { return idx[axis]; }, is_first);
};
// reduce within blocks
auto *data_layout = layouts_->get(layouts_->tmp(x));
auto *data_ptr =
cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
auto *index_ptr =
with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)),
IntegerType::get(*ctx_, 32))
: data_ptr;
auto shape = data_layout->get_shape();
auto order = data_layout->get_order();
Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
for(auto& x: accs) {
// current element being computed
std::pair<Value *, Value *> acc = x.second;
indices_t write_idx = x.first;
write_idx[axis] = lane;
// shared memory write pointer
Value *write_off = shared_off(shape, order, write_idx);
Value *write_ptr = gep(data_ptr, write_off);
Value *index_write_ptr = gep(index_ptr, write_off);
// initialize shared memory
add_barrier();
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
// build result
indices_t idx(write_idx.size(), i32(0));
for(size_t i = shape[axis]/2; i > 0; i >>= 1){
@@ -2468,11 +2507,17 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
Value *read_msk = icmp_ult(lane, i32(i));
Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
Value *read_ptr = gep(write_ptr, read_off);
Value *index_read_ptr = gep(index_write_ptr, read_off);
add_barrier();
// update accumulator
do_acc(
acc, [&]() -> Value * { return load(read_ptr); },
[&]() -> Value * { return load(index_read_ptr); }, false);
add_barrier();
store(acc.first, write_ptr);
if (with_index) {
store(acc.second, index_write_ptr);
}
}
}
add_barrier();
@@ -2482,7 +2527,8 @@ void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Val
indices_t read_idx = idx;
read_idx.insert(read_idx.begin() + axis, i32(0));
Value *read_off = shared_off(shape, order, read_idx);
Value *read_ptr =
with_index ? gep(index_ptr, read_off) : gep(data_ptr, read_off);
vals_[x][idx] = load(read_ptr);
};
}
@@ -2494,45 +2540,75 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
Type *ty = cvt(x->get_type()->get_scalar_ty());
// accumulation function
ir::reduce_inst::op_t op = x->get_op();
auto do_acc_op = [&](Value *x, Value *y) -> Value* {
switch(op){
case ir::reduce_inst::ADD: return add(x, y);
case ir::reduce_inst::SUB: return sub(x, y);
case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
case ir::reduce_inst::FADD: return fadd(x, y);
case ir::reduce_inst::FSUB: return fsub(x, y);
case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
case ir::reduce_inst::FMAX: return max_num(x, y);
case ir::reduce_inst::FMIN: return min_num(x, y);
case ir::reduce_inst::XOR: return xor_(x, y);
default: throw std::runtime_error("unreachable");
}
};
auto do_acc = [&](std::pair<Value *, Value *> &acc,
std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn,
bool is_first) -> void {
auto *val = load_value_fn();
if (x->with_index()) {
auto *index = load_index_fn();
if (is_first) {
acc.first = val;
acc.second = index;
} else {
Value *ret = do_acc_op(acc.first, val);
acc.first = select(ret, acc.first, val);
acc.second = select(ret, acc.second, index);
}
} else {
acc.first = is_first ? val : do_acc_op(acc.first, val);
}
};
// neutral element
Value *neutral;
switch(op) {
case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
default: throw std::runtime_error("unreachable");
}
ir::value *arg = x->get_operand(0);
bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
bool is_a100_mma = layouts_->is_a100_mma(x);
if (is_coalesced_scanline || is_a100_mma)
visit_reducend_inst_fast(x, do_acc, neutral);
else
visit_reducend_inst(x, do_acc, neutral);
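For plain reductions do_acc_op returns the combined value, but for the ARG* ops it returns the comparison predicate, and do_acc uses that predicate to select both halves of the pair at once. A host-side sketch of the resulting ARGMAX semantics on plain integers (sequential combine order; the in-warp order differs, as noted earlier):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for do_acc_op under ARGMAX (icmp_sge): true keeps the accumulator.
static bool acc_wins(int32_t acc, int32_t val) { return acc >= val; }

int main() {
  std::vector<int32_t> data = {3, 9, 4, 9, 1};
  std::pair<int32_t, int32_t> acc{0, 0}; // <running max, its index>
  for (int32_t i = 0; i < (int32_t)data.size(); ++i) {
    bool is_first = (i == 0);
    if (is_first || !acc_wins(acc.first, data[i]))
      acc = {data[i], i}; // select() keeps or replaces both halves together
  }
  std::cout << acc.first << " at index " << acc.second << "\n"; // 9 at index 1
}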
@@ -2938,6 +3014,13 @@ void generator::forward_declare(ir::function* fn){
fns_[fn] = ret;
}
Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout,
Type *ty) {
unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
return base;
}
void generator::visit_function(ir::function* fn) {
idxs_.clear();
vals_.clear();

View File

@@ -60,15 +60,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_layout, b_tmp_index) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp) ||
intersect_with(a_tmp, b_tmp_index) ||
intersect_with(a_tmp_index, b_layout) ||
intersect_with(a_tmp_index, b_tmp) ||
intersect_with(a_tmp_index, b_tmp_index))
ret.insert(b);
}
}
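A value can now own up to three shared allocations: its main layout, the reduce scratch (tmp), and the index scratch (tmp_index), so the barrier analysis must test all nine pairings between the two sides. The expanded condition above enumerates them explicitly; a sketch of the same check as a loop, with stand-in live ranges instead of the pass's real layout type:

#include <array>

struct shared_layout_t { int first = 0, last = 0; }; // stand-in live range

// Stand-in for membar::intersect_with on two layouts (null = absent buffer).
static bool intersect_with(const shared_layout_t *a, const shared_layout_t *b) {
  return a && b && a->first <= b->last && b->first <= a->last;
}

static bool any_intersection(const std::array<const shared_layout_t *, 3> &as,
                             const std::array<const shared_layout_t *, 3> &bs) {
  for (const auto *a : as)
    for (const auto *b : bs)
      if (intersect_with(a, b))
        return true;
  return false;
}

int main() {
  shared_layout_t a_val{0, 4}, a_idx{0, 4}, b_tmp{3, 6};
  // {layout, tmp, tmp_index}; absent buffers are null, as in the pass.
  return any_intersection({{&a_val, nullptr, &a_idx}},
                          {{nullptr, &b_tmp, nullptr}})
             ? 0
             : 1; // ranges [0,4] and [3,6] overlap -> returns 0
}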