diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 961aea725..74a617af9 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -179,7 +179,6 @@ private: void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); - void lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); // matrix multiply diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 79c8214c7..762fd90db 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -799,8 +799,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { axes[d].values = {builder.getInt32(0)}; } } - bool vectorize = dynamic_cast(v); - distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, vectorize); + distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false); bool is_inserted = tmap_.insert({v, T}).second; // constant range if(is_inserted && dynamic_cast(v)){ @@ -890,8 +889,25 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand()); tile *scalars = tmap_.at(x->get_value_operand()); +// size_t ld = tiles_->order(x->get_pointer_operand())[0]; +// unsigned vector_size = 2; +// // vectorize pointers +// std::map ptr_packets; +// ptrs->for_each([&](indices_t idx){ +// unsigned linear = ptrs->get_linear_index(idx); +// unsigned id = linear / vector_size; +// if(linear % vector_size == 0) { +// Value *ptr = ptrs->get_value(idx); +// ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(ptr->getType()->getPointerElementType(), vector_size), +// ptr->getType()->getPointerAddressSpace())); +// ptr_packets[id] = ptr; +// } +// }); +// ((shared_tile*)(scalars))->set_vector_size(vector_size); +// ((shared_tile*)(scalars))->set_return_mode(true); + // extract result element ptrs->for_each([&](indices_t idx){ - builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx)); + builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx)); }); } @@ -1018,10 +1034,13 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio }); } -void selection::lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { - distributed_tile* result = (distributed_tile*)tmap_.at(x); - distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0)); - unsigned vector_size = result->axis(0).contiguous; +void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { + shared_tile* result = (shared_tile*)tmap_.at(x); + ir::value *arg = x->get_operand(0); + distributed_tile* in = (distributed_tile*)tmap_.at(arg); + size_t ld = tiles_->order(arg)[0]; + unsigned vector_size = in->axis(ld).contiguous; + std::map packets; in->for_each([&](indices_t idx){ unsigned linear = in->get_linear_index(idx); @@ -1031,7 +1050,7 @@ void selection::lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Functio packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); packets[id] = builder.CreateInsertElement(packets.at(id), in_value, linear % vector_size); }); - result->for_each([&](indices_t idx){ + in->for_each([&](indices_t idx){ unsigned linear = in->get_linear_index(idx); unsigned id = linear / vector_size; if(linear % vector_size == 0) @@ -1039,14 +1058,6 @@ void selection::lower_vectorize(ir::vectorize_inst *x, LLVMContext &ctx, Functio }); } -void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { - shared_tile* result = (shared_tile*)tmap_.at(x); - distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0)); - in->for_each([&](indices_t idx){ - result->set_value(idx, in->get_value(idx)); - }); -} - void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { shared_tile* result = (shared_tile*)tmap_.at(x); distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0)); @@ -1400,8 +1411,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & lower_splat(x, ctx, fn, builder); else if(auto *x = dynamic_cast(ins)) lower_broadcast(x, ctx, fn, builder); - else if(auto *x = dynamic_cast(ins)) - lower_vectorize(x, ctx, fn, builder); else if(auto *x = dynamic_cast(ins)) lower_copy_to_shared(x, ctx, fn, builder); else if(auto* x = dynamic_cast(ins)) diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 0e435e663..b0d1a3521 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -87,7 +87,7 @@ void coalesce::run(ir::module &mod) { // rematerialize operands std::map seen; for(ir::value *op: r->ops()) - rematerialize(op, mod.get_builder(), seen); + r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen)); // copy to shared if load auto& inst_list = r->get_parent()->get_inst_list(); auto pos = ++std::find(inst_list.begin(), inst_list.end(), r); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 7908f8ec7..7db1e1af1 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -208,7 +208,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts); codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); - codegen::transform::vectorize vectorize(&tiles); codegen::transform::dce dce; codegen::transform::peephole peephole; codegen::transform::reassociate reassociate(&align); @@ -235,7 +234,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c return std::unique_ptr(); shmem_barriers.run(module); dce.run(module); - vectorize.run(module); dce.run(module); axes.run(module); layouts.run(module); diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h index 58651a84f..b1d571b51 100644 --- a/tests/common/src/copy.h +++ b/tests/common/src/copy.h @@ -37,8 +37,8 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16), int ridn = get_program_id(1); int rm[TM] = ridm * TM + 0 ... TM; int rn[TN] = ridn * TN + 0 ... TN; - TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; - TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy; + TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :] ; + TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy; *py = *px; } )";