diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 48a15c4af..e2303b990 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1285,13 +1285,35 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
   // vector size
   int vec = 1;
+  Value *mask = builder_->getInt1(true);
   if(atom->get_type()->is_block_ty()){
+    auto shape = atom->get_type()->get_block_shapes();
    int ld = ords_.at(ptr)[0];
    unsigned alignment = alignment_->get(ptr, ld);
    vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
    vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
+    // mask out inactive threads
+    analysis::data_layout* layout = layouts_->get(val);
+    auto curr_axes = a_axes_->get(val);
+    auto layt_axes = layout->get_axes();
+    for(unsigned k = 0; k < layt_axes.size(); k++){
+      unsigned ax = layt_axes.at(k);
+      distributed_axis dax = axes_.at(ax);
+      // axis is part of the original layout: thread id should be 0
+      // but not the current layout
+      if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
+        mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
+    }
+    // last axis may spillover
+    Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
+    int per_thread = 1;
+    for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
+    int numel = 1;
+    for(int s: layout->get_shape()) { numel *= s; }
+    mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
   }
+
   for(int i = 0; i < idxs_.at(val).size(); i += vec){
     auto idx = idxs_[val][i];
     Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
@@ -1299,6 +1321,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
       rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
     Value *rmw_ptr = vals_[ptr][idx];
     Value *rmw_msk = vals_[msk][idx];
+    rmw_msk = and_(rmw_msk, mask);
     if(vec == 1)
       rmw_val = extract_elt(rmw_val, i32(0));
     Type* ty = rmw_val->getType();
@@ -3400,20 +3423,20 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
 }
 
 void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
-  Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
+  Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
   auto order = layout->get_order();
   const auto& shape = layout->get_shape();
   // Delinearize
   size_t dim = shape.size();
-  std::vector<Value*> thread_id(dim);
+  std::vector<Value*> thread_ids(dim);
   for(unsigned k = 0; k < dim - 1; k++){
     Constant *dim_k = i32(layout->mts(order[k]));
-    Value *rem = urem(u_thread_id, dim_k);
-    u_thread_id = udiv(u_thread_id, dim_k);
-    thread_id[order[k]] = rem;
+    Value *rem = urem(thread_id, dim_k);
+    thread_id = udiv(thread_id, dim_k);
+    thread_ids[order[k]] = rem;
   }
   Constant *dim_k = i32(layout->mts(order[dim - 1]));
-  thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
+  thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
   // Create axes
   for(unsigned k = 0; k < dim; k++) {
@@ -3421,15 +3444,15 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
     int mts = layout->mts(k);
     std::string str_k = std::to_string(k);
     Value *contiguous_k = i32(nts);
-    Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
+    Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
     unsigned per_cta = layout->shape_per_cta(k);
     unsigned per_thread = nts * shape[k] / per_cta;
     std::vector<Value*> idx_list(per_thread);
     for(unsigned n = 0 ; n < per_thread; n++){
       unsigned offset = n / nts * per_cta + n % nts;
-      idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
+      idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
     }
-    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
+    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
   }
 }
diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc
index 8b5ad3625..862ad1efe 100644
--- a/lib/codegen/transform/coalesce.cc
+++ b/lib/codegen/transform/coalesce.cc
@@ -15,42 +15,6 @@ namespace transform{
 coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
   : align_(align), layout_(layouts), has_sm80_(has_sm80) { }
 
-
-// simplify layout conversions using the following simple rules:
-//   - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
-//   - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
-//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
-//  ir::value* _op = inst->get_operand(0);
-//  ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
-//  analysis::mma_layout* mma_in  = layout_->get(op)  ->to_mma();
-//  analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
-//  std::cout << 1 << std::endl;
-//  // i must be layout conversion instruction
-//  if(!mma_in && !mma_out)
-//    return inst;
-//  // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
-//  bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
-//  if((mma_in || mma_out) && is_op_cvt &&
-//     (layout_->get(inst) == layout_->get(op->get_operand(0))))
-//    return op->get_operand(0);
-//  // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
-//  if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
-//    return inst;
-//  std::cout << 1 << std::endl;
-//  for(size_t i = 0; i < op->get_num_operands(); i++){
-//    ir::value* arg_i = op->get_operand(i);
-//    builder.set_insert_point(op);
-//    // create new layout transform
-//    ir::instruction* new_arg_i = inst->clone();
-//    builder.insert(new_arg_i);
-//    // set the right args
-//    new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
-//    op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
-//  }
-//  std::cout << 2 << std::endl;
-//  return op;
-//}
-
 void coalesce::run(ir::module &mod) {
   std::set<analysis::data_layout*> invalidated;
   ir::builder& builder = mod.get_builder();
@@ -62,7 +26,7 @@ void coalesce::run(ir::module &mod) {
     if(dynamic_cast(i) || dynamic_cast(i))
     if(ir::value* op = i->get_operand(1))
     if(op->get_type()->is_block_ty())
-    if(op->get_type()->get_tile_rank() == 2)
+    if(op->get_type()->get_tile_ranks1() == 2)
     if(invalidated.find(layout_->get(op)) == invalidated.end())
     if(layout_->get(op)->to_mma())
     if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
@@ -78,7 +42,7 @@ void coalesce::run(ir::module &mod) {
     if(dynamic_cast(i) || dynamic_cast(i))
     if(ir::value* op = i->get_operand(0))
     if(op->get_type()->is_block_ty())
-    if(op->get_type()->get_tile_rank() == 2)
+    if(op->get_type()->get_tile_ranks1() == 2)
     if(invalidated.find(layout_->get(op)) == invalidated.end())
     if(layout_->get(op)->to_mma()){
       ir::instruction* new_op = ir::cvt_layout_inst::create(op);
@@ -91,7 +55,7 @@ void coalesce::run(ir::module &mod) {
     // uncoalesce after load
     if(auto x = dynamic_cast(i))
     if(x->get_type()->is_block_ty())
-    if(x->get_type()->get_tile_rank()==2)
+    if(x->get_type()->get_tile_ranks1()==2)
     if(layout_->get(x)->to_mma())
     if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
       builder.set_insert_point_after(x);
@@ -111,9 +75,11 @@ void coalesce::run(ir::module &mod) {
     auto out_contig = align_->contiguous(ptr);
     auto val_inst = dynamic_cast(val);
     if(!val_inst)
-      break;
+      continue;
     if(dynamic_cast(val))
-      break;
+      continue;
+    if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
+      continue;
     std::vector in_contig;
     std::vector queue = {val_inst};
     std::set seen;
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 92c854f06..d032d1e39 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -532,6 +532,29 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel
+
+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
+    rs = RandomState(17)
+    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    # reference result
+    z_ref = np.sum(x, axis=axis)
+    # triton result
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
+
+
 def test_atomic_cas():
     # 1. make sure that atomic_cas changes the original value (Lock)
     @triton.jit
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index ee99aab2e..cc0db5566 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -370,6 +370,17 @@ class constexpr:
     def __call__(self, *args, **kwds):
         return self.value(*args, **kwds)
 
+    def to(self, dtype, bitcast=False, _builder=None):
+        if dtype in [float8, float16, bfloat16]:
+            raise ValueError("floating point constexpr must be float64")
+        if dtype.is_int():
+            ret_ty = int
+        elif dtype.is_bool():
+            ret_ty = bool
+        elif dtype.is_floating():
+            ret_ty = float
+        return constexpr(ret_ty(self.value))
+
 
 class tensor:
     # infer dtype from ir type
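
Note for reviewers (not part of the patch): a minimal sketch of the kind of kernel this change enables, namely tl.atomic_add applied to a whole block of pointers, mirroring the new test_tensor_atomic_rmw test above. The kernel name, the use of torch tensors, the fixed 8x8 shape, and the single-program grid are illustrative assumptions rather than anything mandated by the diff.

    # illustrative usage sketch; names and shapes are assumptions
    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def col_sum_atomic_kernel(Z, X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        # load an entire SHAPE0 x SHAPE1 block, reduce along axis 0,
        # then atomically accumulate the SHAPE1 partial sums into Z
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
        z = tl.sum(x, axis=0)
        tl.atomic_add(Z + off1, z)  # tensor-level atomic add exercised by this patch


    x = torch.randn((8, 8), device="cuda", dtype=torch.float32)
    z = torch.zeros((8,), device="cuda", dtype=torch.float32)
    col_sum_atomic_kernel[(1,)](z, x, 8, 8)
    assert torch.allclose(z, x.sum(dim=0), rtol=1e-4)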