[BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558, #573, and #574.
@@ -1285,13 +1285,35 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
 
   // vector size
   int vec = 1;
+  Value *mask = builder_->getInt1(true);
   if(atom->get_type()->is_block_ty()){
+    auto shape = atom->get_type()->get_block_shapes();
     int ld = ords_.at(ptr)[0];
     unsigned alignment = alignment_->get(ptr, ld);
     vec = std::min<int>(layouts_->get(ptr)->to_scanline()->nts(ld), alignment);
     vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 2 : 1);
+    // mask out inactive threads
+    analysis::data_layout* layout = layouts_->get(val);
+    auto curr_axes = a_axes_->get(val);
+    auto layt_axes = layout->get_axes();
+    for(unsigned k = 0; k < layt_axes.size(); k++){
+      unsigned ax = layt_axes.at(k);
+      distributed_axis dax = axes_.at(ax);
+      // axis is part of the original layout: thread id should be 0
+      // but not the current layout
+      if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end())
+        mask = and_(mask, icmp_eq(dax.thread_id, i32(0)));
+    }
+    // last axis may spillover
+    Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0);
+    int per_thread = 1;
+    for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; }
+    int numel = 1;
+    for(int s: layout->get_shape()) { numel *= s; }
+    mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel)));
   }
+
 
   for(int i = 0; i < idxs_.at(val).size(); i += vec){
     auto idx = idxs_[val][i];
     Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec));
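For context, a rough Python model of the predicate built by the new masking code (illustrative helper and argument names, not part of the codebase; the generator emits LLVM IR values instead of evaluating anything):

# Rough model of the thread mask added above (illustrative names only).
def rmw_thread_mask(linear_thread_id, layout_axes, value_axes,
                    axis_thread_id, axis_contiguous, layout_shape):
    # A thread is masked off if it has a non-zero id along an axis that the
    # layout distributes over but that the updated value does not use;
    # otherwise several threads would repeat the same atomic update.
    active = all(axis_thread_id[ax] == 0
                 for ax in layout_axes if ax not in value_axes)
    # The last axis may spill over: threads whose first element lies past the
    # end of the layout are masked off as well.
    per_thread = 1
    for ax in layout_axes:
        per_thread *= axis_contiguous[ax]
    numel = 1
    for s in layout_shape:
        numel *= s
    return active and linear_thread_id * per_thread < numel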
@@ -1299,6 +1321,7 @@ void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) {
       rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii);
     Value *rmw_ptr = vals_[ptr][idx];
     Value *rmw_msk = vals_[msk][idx];
+    rmw_msk = and_(rmw_msk, mask);
     if(vec == 1)
       rmw_val = extract_elt(rmw_val, i32(0));
     Type* ty = rmw_val->getType();
@@ -3400,20 +3423,20 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
 }
 
 void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
-  Value* u_thread_id = tgt_->get_local_id(mod_, *builder_, 0);
+  Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0);
   auto order = layout->get_order();
   const auto& shape = layout->get_shape();
   // Delinearize
   size_t dim = shape.size();
-  std::vector<Value*> thread_id(dim);
+  std::vector<Value*> thread_ids(dim);
   for(unsigned k = 0; k < dim - 1; k++){
     Constant *dim_k = i32(layout->mts(order[k]));
-    Value *rem = urem(u_thread_id, dim_k);
-    u_thread_id = udiv(u_thread_id, dim_k);
-    thread_id[order[k]] = rem;
+    Value *rem = urem(thread_id, dim_k);
+    thread_id = udiv(thread_id, dim_k);
+    thread_ids[order[k]] = rem;
   }
   Constant *dim_k = i32(layout->mts(order[dim - 1]));
-  thread_id[order[dim - 1]] = urem(u_thread_id, dim_k);
+  thread_ids[order[dim - 1]] = urem(thread_id, dim_k);
 
   // Create axes
   for(unsigned k = 0; k < dim; k++) {
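The rename separates the linear thread id (thread_id) from the per-dimension ids (thread_ids). A minimal Python sketch of the delinearization performed here, assuming mts[d] is the number of threads along dimension d (a stand-in for the emitted urem/udiv instructions):

# Minimal sketch of the delinearization in visit_layout_scanline.
def delinearize(linear_id, order, mts):
    dim = len(order)
    thread_ids = [0] * dim
    for k in range(dim - 1):
        size_k = mts[order[k]]              # threads along dimension order[k]
        thread_ids[order[k]] = linear_id % size_k
        linear_id //= size_k
    thread_ids[order[-1]] = linear_id % mts[order[-1]]
    return thread_ids

# e.g. with order=[0, 1] and mts=[8, 4], thread 37 maps to [37 % 8, (37 // 8) % 4] == [5, 0]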
@@ -3421,15 +3444,15 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
     int mts = layout->mts(k);
     std::string str_k = std::to_string(k);
     Value *contiguous_k = i32(nts);
-    Value *scaled_thread_id = mul(thread_id[k], contiguous_k);
+    Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k);
     unsigned per_cta = layout->shape_per_cta(k);
     unsigned per_thread = nts * shape[k] / per_cta;
     std::vector<Value*> idx_list(per_thread);
     for(unsigned n = 0 ; n < per_thread; n++){
       unsigned offset = n / nts * per_cta + n % nts;
-      idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
+      idx_list[n] = add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n));
     }
-    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
+    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]};
   }
 }
 
@@ -15,42 +15,6 @@ namespace transform{
 coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
   : align_(align), layout_(layouts), has_sm80_(has_sm80) { }
 
-
-// simplify layout conversions using the following simple rules:
-//   - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
-//   - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
-//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
-//  ir::value* _op = inst->get_operand(0);
-//  ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
-//  analysis::mma_layout* mma_in  = layout_->get(op)  ->to_mma();
-//  analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
-//  std::cout << 1 << std::endl;
-//  // i must be layout conversion instruction
-//  if(!mma_in && !mma_out)
-//    return inst;
-//  // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
-//  bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
-//  if((mma_in || mma_out) && is_op_cvt &&
-//     (layout_->get(inst) == layout_->get(op->get_operand(0))))
-//    return op->get_operand(0);
-//  // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
-//  if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
-//    return inst;
-//  std::cout << 1 << std::endl;
-//  for(size_t i = 0; i < op->get_num_operands(); i++){
-//    ir::value* arg_i = op->get_operand(i);
-//    builder.set_insert_point(op);
-//    // create new layout transform
-//    ir::instruction* new_arg_i = inst->clone();
-//    builder.insert(new_arg_i);
-//    // set the right args
-//    new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
-//    op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
-//  }
-//  std::cout << 2 << std::endl;
-//  return op;
-//}
-
 void coalesce::run(ir::module &mod) {
   std::set<analysis::data_layout*> invalidated;
   ir::builder& builder = mod.get_builder();
@@ -62,7 +26,7 @@ void coalesce::run(ir::module &mod) {
     if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
     if(ir::value* op = i->get_operand(1))
     if(op->get_type()->is_block_ty())
-    if(op->get_type()->get_tile_rank() == 2)
+    if(op->get_type()->get_tile_ranks1() == 2)
    if(invalidated.find(layout_->get(op)) == invalidated.end())
    if(layout_->get(op)->to_mma())
    if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
@@ -78,7 +42,7 @@ void coalesce::run(ir::module &mod) {
     if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
     if(ir::value* op = i->get_operand(0))
     if(op->get_type()->is_block_ty())
-    if(op->get_type()->get_tile_rank() == 2)
+    if(op->get_type()->get_tile_ranks1() == 2)
     if(invalidated.find(layout_->get(op)) == invalidated.end())
     if(layout_->get(op)->to_mma()){
       ir::instruction* new_op = ir::cvt_layout_inst::create(op);
@@ -91,7 +55,7 @@ void coalesce::run(ir::module &mod) {
     // uncoalesce after load
     if(auto x = dynamic_cast<ir::load_inst*>(i))
     if(x->get_type()->is_block_ty())
-    if(x->get_type()->get_tile_rank()==2)
+    if(x->get_type()->get_tile_ranks1()==2)
     if(layout_->get(x)->to_mma())
     if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
       builder.set_insert_point_after(x);
@@ -111,9 +75,11 @@ void coalesce::run(ir::module &mod) {
       auto out_contig = align_->contiguous(ptr);
       auto val_inst = dynamic_cast<ir::instruction*>(val);
       if(!val_inst)
-        break;
+        continue;
       if(dynamic_cast<ir::cvt_layout_inst*>(val))
-        break;
+        continue;
+      if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
+        continue;
       std::vector<unsigned> in_contig;
       std::vector<ir::instruction*> queue = {val_inst};
       std::set<ir::instruction*> seen;
@@ -532,6 +532,29 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
 
 
+@pytest.mark.parametrize("axis", [0, 1])
+def test_tensor_atomic_rmw(axis, device="cuda"):
+    shape0, shape1 = 8, 8
+    # triton kernel
+
+    @triton.jit
+    def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+        off0 = tl.arange(0, SHAPE0)
+        off1 = tl.arange(0, SHAPE1)
+        x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+        z = tl.sum(x, axis=AXIS)
+        tl.atomic_add(Z + off0, z)
+    rs = RandomState(17)
+    x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs)
+    # reference result
+    z_ref = np.sum(x, axis=axis)
+    # triton result
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(np.zeros((shape0,), dtype="float32"), device=device)
+    kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
+    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
+
+
 def test_atomic_cas():
     # 1. make sure that atomic_cas changes the original value (Lock)
     @triton.jit
@@ -370,6 +370,17 @@ class constexpr:
     def __call__(self, *args, **kwds):
         return self.value(*args, **kwds)
 
+    def to(self, dtype, bitcast=False, _builder=None):
+        if dtype in [float8, float16, bfloat16]:
+            raise ValueError("floating point constexpr must be float64")
+        if dtype.is_int():
+            ret_ty = int
+        elif dtype.is_bool():
+            ret_ty = bool
+        elif dtype.is_floating():
+            ret_ty = float
+        return constexpr(ret_ty(self.value))
+
 
 class tensor:
     # infer dtype from ir type
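A sketch of how the new constexpr.to could be exercised from a jitted kernel; the kernel below is illustrative, not taken from this PR, and assumes the coercion rules added above (the constexpr's Python value is converted to int, bool, or float and wrapped back in a constexpr):

import triton
import triton.language as tl

@triton.jit
def scale_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # BLOCK is a tl.constexpr; .to(tl.float32) now coerces its Python value
    # (an int) to a float constexpr instead of failing, so it can be mixed
    # with tensor arithmetic like any other scalar.
    tl.store(Y + offs, x * BLOCK.to(tl.float32))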