[BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558 , #573 and #574
This commit is contained in:
@@ -1285,13 +1285,35 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
|
||||
// vector size
|
||||
int vec = 1;
|
||||
Value *mask = builder_->getInt1(true);
|
||||
if(atom->get_type()->is_block_ty()){
|
||||
auto shape = atom->get_type()->get_block_shapes();
|
||||
int ld = ords_.at(ptr)[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
|
||||
vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
|
||||
// mask out inactive threads
|
||||
analysis::data_layout* layout = layouts_->get(val);
|
||||
auto curr_axes = a_axes_->get(val);
|
||||
auto layt_axes = layout->get_axes();
|
||||
for(unsigned k = 0; k < layt_axes.size(); k++){
|
||||
unsigned ax = layt_axes.at(k);
|
||||
distributed_axis dax = axes_.at(ax);
|
||||
// axis is part of the original layout: thread id should be 0
|
||||
// but not the current layout
|
||||
if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
|
||||
mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
|
||||
}
|
||||
// last axis may spillover
|
||||
Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
int per_thread = 1;
|
||||
for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
|
||||
int numel = 1;
|
||||
for(int s: layout->get_shape()) { numel *= s; }
|
||||
mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
|
||||
}
|
||||
|
||||
|
||||
for(int i = 0; i < idxs_.at(val).size(); i += vec){
|
||||
auto idx = idxs_[val][i];
|
||||
Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
|
||||
@@ -1299,6 +1321,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
|
||||
rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
|
||||
Value *rmw_ptr = vals_[ptr][idx];
|
||||
Value *rmw_msk = vals_[msk][idx];
|
||||
rmw_msk = and_(rmw_msk, mask);
|
||||
if(vec == 1)
|
||||
rmw_val = extract_elt(rmw_val, i32(0));
|
||||
Type* ty = rmw_val->getType();
|
||||
@@ -3400,20 +3423,20 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
auto order = layout->get_order();
|
||||
const auto& shape = layout->get_shape();
|
||||
// Delinearize
|
||||
size_t dim = shape.size();
|
||||
std::vector<Value*> thread_id(dim);
|
||||
std::vector<Value*> thread_ids(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = i32(layout->mts(order[k]));
|
||||
Value *rem = urem(u_thread_id, dim_k);
|
||||
u_thread_id = udiv(u_thread_id, dim_k);
|
||||
thread_id[order[k]] = rem;
|
||||
Value *rem = urem(thread_id, dim_k);
|
||||
thread_id = udiv(thread_id, dim_k);
|
||||
thread_ids[order[k]] = rem;
|
||||
}
|
||||
Constant *dim_k = i32(layout->mts(order[dim - 1]));
|
||||
thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
|
||||
thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
|
||||
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
@@ -3421,15 +3444,15 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
int mts = layout->mts(k);
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = i32(nts);
|
||||
Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
|
||||
Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
|
||||
unsigned per_cta = layout->shape_per_cta(k);
|
||||
unsigned per_thread = nts * shape[k] / per_cta;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts * per_cta + n % nts;
|
||||
idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -15,42 +15,6 @@ namespace transform{
|
||||
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
|
||||
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
|
||||
|
||||
|
||||
// simplify layout conversions using the following simple rules:
|
||||
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
|
||||
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
|
||||
// ir::value* _op = inst->get_operand(0);
|
||||
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
|
||||
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
|
||||
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
|
||||
// std::cout << 1 << std::endl;
|
||||
// // i must be layout conversion instruction
|
||||
// if(!mma_in && !mma_out)
|
||||
// return inst;
|
||||
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
|
||||
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
|
||||
// if((mma_in || mma_out) && is_op_cvt &&
|
||||
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
|
||||
// return op->get_operand(0);
|
||||
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
|
||||
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
|
||||
// return inst;
|
||||
// std::cout << 1 << std::endl;
|
||||
// for(size_t i = 0; i < op->get_num_operands(); i++){
|
||||
// ir::value* arg_i = op->get_operand(i);
|
||||
// builder.set_insert_point(op);
|
||||
// // create new layout transform
|
||||
// ir::instruction* new_arg_i = inst->clone();
|
||||
// builder.insert(new_arg_i);
|
||||
// // set the right args
|
||||
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
|
||||
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
|
||||
// }
|
||||
// std::cout << 2 << std::endl;
|
||||
// return op;
|
||||
//}
|
||||
|
||||
void coalesce::run(ir::module &mod) {
|
||||
std::set<analysis::data_layout*> invalidated;
|
||||
ir::builder& builder = mod.get_builder();
|
||||
@@ -62,7 +26,7 @@ void coalesce::run(ir::module &mod) {
|
||||
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
|
||||
if(ir::value* op = i->get_operand(1))
|
||||
if(op->get_type()->is_block_ty())
|
||||
if(op->get_type()->get_tile_rank() == 2)
|
||||
if(op->get_type()->get_tile_ranks1() == 2)
|
||||
if(invalidated.find(layout_->get(op)) == invalidated.end())
|
||||
if(layout_->get(op)->to_mma())
|
||||
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
|
||||
@@ -78,7 +42,7 @@ void coalesce::run(ir::module &mod) {
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
|
||||
if(ir::value* op = i->get_operand(0))
|
||||
if(op->get_type()->is_block_ty())
|
||||
if(op->get_type()->get_tile_rank() == 2)
|
||||
if(op->get_type()->get_tile_ranks1() == 2)
|
||||
if(invalidated.find(layout_->get(op)) == invalidated.end())
|
||||
if(layout_->get(op)->to_mma()){
|
||||
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
|
||||
@@ -91,7 +55,7 @@ void coalesce::run(ir::module &mod) {
|
||||
// uncoalesce after load
|
||||
if(auto x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(x->get_type()->is_block_ty())
|
||||
if(x->get_type()->get_tile_rank()==2)
|
||||
if(x->get_type()->get_tile_ranks1()==2)
|
||||
if(layout_->get(x)->to_mma())
|
||||
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
|
||||
builder.set_insert_point_after(x);
|
||||
@@ -111,9 +75,11 @@ void coalesce::run(ir::module &mod) {
|
||||
auto out_contig = align_->contiguous(ptr);
|
||||
auto val_inst = dynamic_cast<ir::instruction*>(val);
|
||||
if(!val_inst)
|
||||
break;
|
||||
continue;
|
||||
if(dynamic_cast<ir::cvt_layout_inst*>(val))
|
||||
break;
|
||||
continue;
|
||||
if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
|
||||
continue;
|
||||
std::vector<unsigned> in_contig;
|
||||
std::vector<ir::instruction*> queue = {val_inst};
|
||||
std::set<ir::instruction*> seen;
|
||||
|
Reference in New Issue
Block a user