Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".

This is a more stable commit that produces bitwise-identical code to earlier
versions. Using commits after this one may lead to slightly different numerics.
Philippe Tillet
2022-06-03 11:36:06 -07:00
parent efa04cac1f
commit a60374a597
11 changed files with 65 additions and 173 deletions
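
For context on the numerics note above: floating-point addition is not associative, so changing the order in which a reduction combines its operands can change the result bit-for-bit even when every individual operation is correctly rounded. A standalone C++ sketch (illustrative only, not part of this commit; the input values are invented to make the effect obvious):

    #include <cstdio>
    #include <cstdint>
    #include <cstring>

    // Expose a float's bit pattern so bitwise differences are visible.
    static uint32_t bits(float x) {
      uint32_t u;
      std::memcpy(&u, &x, sizeof(u));
      return u;
    }

    int main() {
      float v[4] = {1e8f, 1.0f, -1e8f, 1.0f};
      // Sequential order, as in a per-thread loop: ((v0+v1)+v2)+v3.
      float seq = ((v[0] + v[1]) + v[2]) + v[3];
      // Tree order, as in a warp-shuffle reduction: (v0+v2)+(v1+v3).
      float tree = (v[0] + v[2]) + (v[1] + v[3]);
      std::printf("seq  = %g (0x%08x)\n", seq, bits(seq));   // 1 (0x3f800000)
      std::printf("tree = %g (0x%08x)\n", tree, bits(tree)); // 2 (0x40000000)
    }

Reverting to the pre-#533 code paths restores the old combination order and hence bitwise-identical output; the faster schedules from #533 produce different, equally valid roundings.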


@@ -88,7 +88,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define f16_ty builder_->getHalfTy()
 #define bf16_ty builder_->getBFloatTy()
 #define f32_ty builder_->getFloatTy()
-#define i1_ty builder_->getInt1Ty()
 #define i8_ty builder_->getInt8Ty()
 #define i16_ty builder_->getInt16Ty()
 #define i32_ty builder_->getInt32Ty()
@@ -737,9 +736,6 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
  * \brief Code Generation for a (synchronous) `load`
  */
 void generator::visit_load_inst(ir::load_inst* x){
-  BasicBlock *current = builder_->GetInsertBlock();
-  Module *module = current->getModule();
-  Value *tid = tgt_->get_local_id(module, *builder_, 0);
   ir::value *op = x->get_pointer_operand();
   ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
   Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
@@ -779,9 +775,6 @@ void generator::visit_load_inst(ir::load_inst* x){
       in_off = 0;
     }
     Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
-    // if(!op->get_type()->is_block_ty()){
-    //   pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0)));
-    // }
     Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
     size_t nbits = dtsize*8;
     // pack sub-words (< 32/64bits) into words
@@ -885,18 +878,6 @@ void generator::visit_load_inst(ir::load_inst* x){
     Value *_ret = call(inlineAsm, args);
-    // if(!op->get_type()->is_block_ty()){
-    //   Value* cond = icmp_eq(tid, i32(0));
-    //   Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3));
-    //   Instruction* bar = add_barrier();
-    //   Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false);
-    //   builder_->SetInsertPoint(term);
-    //   store(_ret, shptr);
-    //   builder_->SetInsertPoint(bar->getParent());
-    //   _ret = load(shptr);
-    //   add_barrier();
-    // }
     // ---
     // extract and store return values
     // ---
@@ -2052,12 +2033,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   // create mma & unpack result, m, n, k are offsets in mat
   auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
-    unsigned cols_per_thread = num_rep_n * 2;
+    unsigned cols_per_thread = num_rep_m * 2;
     std::vector<size_t> idx = {
-      (m + 0)*cols_per_thread + (n*2 + 0),
-      (m + 0)*cols_per_thread + (n*2 + 1),
-      (m + 1)*cols_per_thread + (n*2 + 0),
-      (m + 1)*cols_per_thread + (n*2 + 1)
+      (m + 0) + (n*2 + 0)*cols_per_thread,
+      (m + 0) + (n*2 + 1)*cols_per_thread,
+      (m + 1) + (n*2 + 0)*cols_per_thread,
+      (m + 1) + (n*2 + 1)*cols_per_thread
     };
     Value *nc = call(mma_ty, mma_fn,
                      {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
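
To see what the index swap above changes: both schemes enumerate the same four accumulator slots produced by one mma instruction, but the removed code packs them with the n-offset fastest while the restored code packs them with the m-offset fastest. A throwaway sketch (the repetition counts are invented; in the generator they come from the tile shape):

    #include <cstdio>

    int main() {
      // Hypothetical repetition counts, for illustration only.
      unsigned num_rep_m = 2, num_rep_n = 2;
      unsigned rm_stride = num_rep_n * 2;  // scheme being reverted
      unsigned cm_stride = num_rep_m * 2;  // scheme being restored
      for (unsigned m = 0; m < 2 * num_rep_m; m += 2)
        for (unsigned n = 0; n < num_rep_n; ++n)
          std::printf("m=%u n=%u  n-fastest: %2u %2u %2u %2u  m-fastest: %2u %2u %2u %2u\n",
                      m, n,
                      (m + 0)*rm_stride + (n*2 + 0), (m + 0)*rm_stride + (n*2 + 1),
                      (m + 1)*rm_stride + (n*2 + 0), (m + 1)*rm_stride + (n*2 + 1),
                      (m + 0) + (n*2 + 0)*cm_stride, (m + 0) + (n*2 + 1)*cm_stride,
                      (m + 1) + (n*2 + 0)*cm_stride, (m + 1) + (n*2 + 1)*cm_stride);
    }

Both orderings are bijections over the same slot range, so either works as long as the producer and every consumer of the accumulator fragment agree on the convention.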
@@ -2335,93 +2316,62 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
 void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
   //
   ir::value *arg = x->get_operand(0);
-  analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
+  analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline();
   std::vector<unsigned> shapes = layout->get_shape();
   Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
-  size_t n_bits = sca_ty->getPrimitiveSizeInBits();
-  std::string n_bits_str = std::to_string(n_bits);
-  std::string cst = (n_bits == 64) ? "l" : "r";
-  FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false);
-  InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true);
-  FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
-  InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
-  Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
-  Value* warp = udiv(thread, i32(32));
-  Value* lane = urem(thread, i32(32));
-  unsigned shuffle_width = 0;
-  unsigned warps_per_inner = 0;
-  auto arg_vals = vals_.at(arg);
-  std::vector<indices_t> arg_idxs = idxs_.at(arg);
-  size_t n_elts = arg_idxs.size();
-  unsigned col_per_thread;
-  Value* warp_i;
-  Value* warp_j;
-  if(analysis::scanline_layout* scanline = layout->to_scanline()){
-    std::vector<int> order = layout->get_order();
-    unsigned mts = scanline->mts(order[0]);
-    shuffle_width = std::min<int>(mts, 32);
-    warps_per_inner = std::max<int>(mts/32, 1);
-    col_per_thread = shapes[order[0]] / mts;
-    warp_i = udiv(warp, i32(warps_per_inner));
-    warp_j = urem(warp, i32(warps_per_inner));
-  }
-  else if(layout->to_mma()){
-    shuffle_width = 4;
-    warps_per_inner = layout->to_mma()->wpt(1);
-    col_per_thread = 16;
-    warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
-    warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
-  }
+  // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
+  std::vector<int> order = layout->get_order();
+  unsigned mts = layout->mts(order[0]);
+  unsigned nts = layout->nts(order[0]);
+  unsigned col_per_thread = shapes[order[0]] / mts;
+  auto idxs = idxs_.at(arg);
+  size_t n_elts = idxs.size();
   //
   Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
   unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
   Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
-  // preds
-  Value* is_lane0 = icmp_eq(lane, i32(0));
-  Value* is_warp0 = icmp_eq(warp, i32(0));
-  Value* is_thread0 = icmp_eq(thread, i32(0));
-  Value* lane_j = urem(lane, i32(shuffle_width));
-  Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
-  add_barrier();
+  // compute partial sum for each warp, and store to shared memory
+  Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
+  Value* warp = udiv(thread, i32(32));
+  Value* lane = urem(thread, i32(32));
+  size_t warps_per_inner = std::max<int>(mts/32, 1);
+  Value* warp_i = udiv(warp, i32(warps_per_inner));
+  unsigned row_per_thread = std::max<int>(32/mts, 1);
   for(size_t i = 0; i < n_elts/col_per_thread; i++){
     Value* acc;
     // reduce within thread
     for(size_t j = 0; j < col_per_thread; j++){
-      Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
-      // acc = (j == 0) ? val : do_acc(acc, val);
+      Value* val = vals_[arg][idxs[i*col_per_thread + j]];
       acc = (j == 0) ? val : do_acc(acc, val);
     }
     // reduce within warp
-    for(int k = shuffle_width/2 ; k > 0; k >>= 1)
+    for(int k = std::min<int>(mts, 32)/2 ; k > 0; k >>= 1)
       acc = do_acc(acc, shfl_sync(acc, k));
-    // store partial result to shared memory
-    auto x_idxs = idxs_[x][i];
-    Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
-    Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
-    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
+    // store warp result in shared memory
+    Value* ret = acc;
+    if(mts >= 32){
+      add_barrier();
+      store(neutral, gep(base, lane));
+      add_barrier();
+      store(acc, gep(base, warp));
+      add_barrier();
+      // reduce across warps
+      Value *cond = icmp_eq(warp, i32(0));
+      Instruction *barrier = add_barrier();
+      builder_->SetInsertPoint(barrier->getParent());
+      Instruction* dummy = builder_->CreateRet(nullptr);
+      Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
+      dummy->removeFromParent();
+      builder_->SetInsertPoint(term);
+      ret = load(gep(base, thread));
+      for(int k = (mts/32)/2; k > 0; k >>= 1){
+        Value *current = shfl_sync(ret, k);
+        ret = do_acc(ret, current);
+      }
+      store(ret, gep(base, thread));
+      builder_->SetInsertPoint(barrier->getParent());
+      ret = load(gep(base, warp));
+    }
+    vals_[x][idxs_[x][i]] = ret;
   }
-  add_barrier();
-  // at this point, partial accumulator synchronized in shared memory
-  // Just need to reduce `warp_per_inner` numbers in shared memory
-  for(size_t i = 0; i < n_elts/col_per_thread; i++){
-    auto x_idxs = idxs_[x][i];
-    Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
-    Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
-    Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
-    for(int k = warps_per_inner/2; k > 0; k >>= 1)
-      acc = do_acc(acc, shfl_sync(acc, k));
-    vals_[x][idxs_[x][i]] = acc;
-  }
-  // add_barrier();
 }
 void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {
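
In both versions above, the within-warp step lowers to a shfl_sync tree: each iteration halves the number of live partial sums. A host-side simulation of that schedule (illustrative only; it assumes a butterfly/XOR shuffle such as PTX shfl.sync.bfly, and uses addition for do_acc):

    #include <cstdio>

    int main() {
      const int WIDTH = 32;  // lanes taking part in the reduction
      float acc[WIDTH];
      for (int lane = 0; lane < WIDTH; ++lane)
        acc[lane] = float(lane + 1);
      // simulates: for(int k = WIDTH/2; k > 0; k >>= 1) acc = do_acc(acc, shfl_sync(acc, k));
      for (int k = WIDTH / 2; k > 0; k >>= 1) {
        float partner[WIDTH];
        for (int lane = 0; lane < WIDTH; ++lane)
          partner[lane] = acc[lane ^ k];            // shfl_sync(acc, k)
        for (int lane = 0; lane < WIDTH; ++lane)
          acc[lane] = acc[lane] + partner[lane];    // do_acc(acc, ...)
      }
      std::printf("every lane holds the full sum: %g\n", acc[0]);  // 1+2+...+32 = 528
    }

When mts >= 32 the reduced dimension spans several warps, so the restored code bounces each warp's partial result through shared memory and runs a smaller shuffle tree over those partials, which is why the barriers around the stores and loads reappear.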
@@ -2521,12 +2471,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
     default: throw std::runtime_error("unreachable");
   }
   ir::value *arg = x->get_operand(0);
-  int cc = tgt_->as_nvidia()->sm();
   analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
-  analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
-  bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
-  bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
-  if(is_coalesced_scanline || is_a100_mma)
+  if(scanline && scanline->get_order()[0] == x->get_axis())
     visit_reducend_inst_fast(x, do_acc, neutral);
   else
     visit_reducend_inst(x, do_acc, neutral);
@@ -2719,12 +2665,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   unsigned in_vec = 1;
   ir::value *arg = cts->get_operand(0);
   analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
-  analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
+  analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
   auto out_order = out_layout->get_order();
   auto in_order = in_layout->get_order();
   // tiles
   if(out_order == in_order)
-    in_vec = in_layout->contig_per_thread(in_order[0]);
+    in_vec = in_layout->nts(in_order[0]);
   int out_vec = swizzle_->get_vec(out_layout);
   int min_vec = std::min<int>(out_vec, in_vec);
   int s = std::max<int>(out_vec / in_vec, 1);
@@ -2732,11 +2678,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   int per_phase = swizzle_->get_per_phase(out_layout);
   int max_phase = swizzle_->get_max_phase(out_layout);
   //
-  int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
-  int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
-  int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
-  int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
+  int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
+  int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
   int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
   BasicBlock* CurrBB = builder_->GetInsertBlock();
@@ -2757,8 +2700,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
     // input ptr info
     int id_0 = id % (in_ld/min_vec);
     int id_1 = id / (in_ld/min_vec);
-    int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
-    int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
+    int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
+    int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
     int off = (off_1*shapes[in_order[0]] + off_0);
     std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
     if(ptrs.find(key) == ptrs.end()){
@@ -3083,7 +3026,8 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
   else{
     /* warp offset */
     Value *warp_0 = urem(warp, i32(layout->wpt(0)));
-    Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1)));
+    Value *warp_12 = udiv(warp, i32(layout->wpt(0)));
+    Value *warp_1 = urem(warp_12, i32(layout->wpt(1)));
     Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
     Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
     Value *off_lane_m = urem(lane, _16);
@@ -3208,9 +3152,7 @@ void generator::visit_basic_block(ir::basic_block * block) {
   BasicBlock *parent = bbs_[block];
   builder_->SetInsertPoint(parent);
   for(ir::instruction *i: block->get_inst_list()){
-    // i->print(std::cout);
     visit_value(i);
-    // std::cout << "done" << std::endl;
   }
   // Update ir bb -> llvm bb mapping
   bbs_[block] = builder_->GetInsertBlock();