[BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558, #573 and #574
This commit is contained in:
@@ -1285,13 +1285,35 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
|
||||
// vector size
|
||||
int vec = 1;
|
||||
Value *mask = builder_->getInt1(true);
|
||||
if(atom->get_type()->is_block_ty()){
|
||||
auto shape = atom->get_type()->get_block_shapes();
|
||||
int ld = ords_.at(ptr)[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
||||
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
|
||||
// mask out inactive threads
|
||||
analysis::data_layout* layout = layouts_->get(val);
|
||||
auto curr_axes = a_axes_->get(val);
|
||||
auto layt_axes = layout->get_axes();
|
||||
for(unsigned k = 0; k < layt_axes.size(); k++){
|
||||
unsigned ax = layt_axes.at(k);
|
||||
distributed_axis dax = axes_.at(ax);
|
||||
// axis is part of the original layout: thread id should be 0
|
||||
// but not the current layout
|
||||
if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
|
||||
mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
|
||||
}
|
||||
// last axis may spillover
|
||||
Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
int per_thread = 1;
|
||||
for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
|
||||
int numel = 1;
|
||||
for(int s: layout->get_shape()) { numel *= s; }
|
||||
mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
|
||||
}
|
||||
|
||||
|
||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||
auto idx = idxs_[val][i];
|
||||
Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
|
||||
@@ -1299,6 +1321,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
|
||||
Value *rmw_ptr = vals_[ptr][idx];
|
||||
Value *rmw_msk = vals_[msk][idx];
|
||||
rmw_msk = and_(rmw_msk, mask);
|
||||
if(vec == 1)
|
||||
rmw_val = extract_elt(rmw_val, i32(0));
|
||||
Type* ty = rmw_val->getType();
|
||||
@@ -3400,20 +3423,20 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
auto order = layout->get_order();
|
||||
const auto& shape = layout->get_shape();
|
||||
// Delinearize
|
||||
size_t dim = shape.size();
|
||||
std::vector<Value*> thread_id(dim);
|
||||
std::vector<Value*> thread_ids(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = i32(layout->mts(order[k]));
|
||||
Value *rem = urem(u_thread_id, dim_k);
|
||||
u_thread_id = udiv(u_thread_id, dim_k);
|
||||
thread_id[order[k]] = rem;
|
||||
Value *rem = urem(thread_id, dim_k);
|
||||
thread_id = udiv(thread_id, dim_k);
|
||||
thread_ids[order[k]] = rem;
|
||||
}
|
||||
Constant *dim_k = i32(layout->mts(order[dim - 1]));
|
||||
thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
|
||||
thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
|
||||
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
@@ -3421,15 +3444,15 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
int mts = layout->mts(k);
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = i32(nts);
|
||||
Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
|
||||
Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
|
||||
unsigned per_cta = layout->shape_per_cta(k);
|
||||
unsigned per_thread = nts * shape[k] / per_cta;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts * per_cta + n % nts;
|
||||
idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user