[codegen][selection] some cleaning
@@ -50,16 +50,15 @@ private:

public:
  grids(size_t num_warps, transform::coalesce* reorder);
  unsigned get_param_group(ir::value *value, unsigned ax);
  fragment_t get_fragment(ir::value *value, unsigned ax);
  void copy(ir::value *dst, ir::value *src);
  void run(ir::module &mod);
  unsigned get_num_threads();
  unsigned get_param_group(ir::value *value, unsigned ax);
  const std::vector<ir::value*> get_grids() const { return grids_; }
  int get_mts(ir::value *value, unsigned ax);
  int get_nts(ir::value *value, unsigned ax);
  int get_fpw(ir::value *value, unsigned ax);
  int get_wpt(ir::value *value, unsigned ax);
  int mts(ir::value *value, unsigned ax);
  int nts(ir::value *value, unsigned ax);
  int fpw(ir::value *value, unsigned ax);
  int wpt(ir::value *value, unsigned ax);

private:
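For illustration only, a minimal self-contained C++ sketch of the accessor pattern behind the renamed mts/nts/fpw/wpt methods, assuming (from the `*mts_.at(value).at(ax)` bodies further down and the `param_ptr_t` type) that each per-axis parameter is a shared integer that several values may alias; the `value` and `tiling_params` types below are simplified stand-ins, not the actual Triton classes:

#include <cassert>
#include <map>
#include <memory>
#include <vector>

struct value {};                                   // stand-in for ir::value
using param_ptr = std::shared_ptr<int>;            // stand-in for param_ptr_t
using param_map = std::map<value*, std::vector<param_ptr>>;

struct tiling_params {
  param_map mts_;   // threads per tile block, one entry per axis
  param_map nts_;   // contiguous elements per thread, one entry per axis
  int mts(value *v, unsigned ax) { return *mts_.at(v).at(ax); }
  int nts(value *v, unsigned ax) { return *nts_.at(v).at(ax); }
};

int main() {
  value a, b;
  tiling_params p;
  auto shared = std::make_shared<int>(32);         // one parameter aliased by a and b
  p.mts_[&a] = {shared};
  p.mts_[&b] = {shared};
  p.nts_[&a] = {std::make_shared<int>(4)};
  *shared = 64;                                    // updating it is visible through both values
  assert(p.mts(&a, 0) == 64 && p.mts(&b, 0) == 64);
  assert(p.nts(&a, 0) == 4);
}

Sharing one parameter object between values is what lets axes that must be tiled identically read back the same number through these thin accessors.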
@@ -157,7 +157,11 @@ private:
  void create_grids(std::vector<ir::value *> &grids,
                    std::map<unsigned, ir::value *> &references,
                    ir::function *fn);
  void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
  void create_distributed_tile(ir::value *v, Builder &builder);
  void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
  void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
  void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
  void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
  void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
@@ -195,8 +199,8 @@ private:

public:
  selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt)
    : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt){ }
  selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt, unsigned num_warps)
    : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }

  void run(ir::module &src, Module &dst);
@@ -215,6 +219,7 @@ private:
  Value *offset_b_j_, *offset_b_k_;
  unsigned num_packs_0_, num_packs_1_;
  unsigned pack_size_0_, pack_size_1_;
  unsigned num_warps_;
};

}
@@ -156,8 +156,8 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
  return STRIDED_SCAN;
}

void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id)
{
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps,
                                 std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
  groups_[x.first].insert({x.second, group_id});
  if(nodes.find(x) != nodes.end()){
    nodes.erase(x);
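For illustration, a minimal sketch of the grouping idea behind grids::connected_components, assuming a plain adjacency map; node_t, graph_t and the group map below are simplified stand-ins for the types in the hunk above, and the parameter-propagation arguments are omitted:

#include <iostream>
#include <map>
#include <set>
#include <vector>

using node_t = int;
using graph_t = std::map<node_t, std::vector<node_t>>;

void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph,
                          std::map<node_t, unsigned> &group, unsigned group_id) {
  group[x] = group_id;                    // every node reachable from x joins the same group
  if (nodes.erase(x))                     // recurse only on the first visit
    for (node_t y : graph[x])
      connected_components(y, nodes, graph, group, group_id);
}

int main() {
  graph_t g = {{0, {1}}, {1, {0}}, {2, {}}};
  std::set<node_t> nodes = {0, 1, 2};
  std::map<node_t, unsigned> group;
  unsigned group_id = 0;
  while (!nodes.empty())
    connected_components(*nodes.begin(), nodes, g, group, group_id++);
  std::cout << group[0] << group[1] << group[2] << "\n";   // prints "001"
}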
@@ -190,22 +190,18 @@ void grids::copy(ir::value *dst, ir::value *src) {


void grids::run(ir::module &mod) {
  ir::context &ctx = mod.get_context();
  // Create metaparameters
  // Create tiling parameters
  for(ir::function *fn: mod.get_function_list()){

    // Build constraints graph
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *i : block->get_inst_list())
    if(i->has_tile_result_or_op())
      init_c_graph(i);

    // Build phi constraints
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction *i : block->get_inst_list())
    if(i->has_tile_result_or_op())
      init_c_phi(i);

    // Layout parameters
    unsigned group_id = 0;
    for(auto x: nodes_)
@@ -231,7 +227,7 @@ void grids::run(ir::module &mod) {
  }

  unsigned num_threads = get_num_threads();
  unsigned num_threads = num_warps_*32;
  auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); };

  for(ir::value *i: grids_){
@@ -242,10 +238,8 @@ void grids::run(ir::module &mod) {
    unsigned size = i->get_type()->get_tile_num_elements();
    /* HMMA parameters*/
    if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){

      unsigned shape_0 = shapes[order[0]];
      unsigned shape_1 = shapes[order[1]];

      /* fragments per warp */
      // try to make things as square as possible to maximize data re-use
      std::vector<unsigned> fpw = {1, 1, 1};
@@ -261,7 +255,6 @@ void grids::run(ir::module &mod) {
      // store parameters
      for(unsigned d = 0; d < shapes.size(); d++)
        *fpw_[i][d] = fpw[d];

      /* warps per tile */
      // try to make things as square as possible to maximize data re-use
      std::vector<unsigned> wpt = {1, 1, 1};
@@ -276,15 +269,12 @@ void grids::run(ir::module &mod) {
      // store parameters
      for(unsigned d = 0; d < shapes.size(); d++)
        *wpt_[i][d] = wpt[d];

      /* sanity check */
      unsigned effective_num_warps = 1;
      for(size_t d = 0; d < shapes.size(); d++)
        effective_num_warps *= *wpt_[i][d];

      if(num_warps_ != effective_num_warps)
        throw std::runtime_error("cannot create a kernel with this amount of warps");

    }

    /* Scan-line */
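The sanity check above reduces to: the product of warps-per-tile over all axes must equal the requested warp count. A hypothetical standalone restatement with a worked example (not the Triton sources, just the same arithmetic):

#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

void check_num_warps(const std::vector<unsigned> &wpt, unsigned num_warps) {
  // effective warps = wpt[0] * wpt[1] * ... ; must match the launch configuration
  unsigned effective = std::accumulate(wpt.begin(), wpt.end(), 1u,
                                       std::multiplies<unsigned>());
  if (effective != num_warps)
    throw std::runtime_error("cannot create a kernel with this amount of warps");
}

int main() {
  check_num_warps({4, 2, 1}, 8);    // ok: 4*2*1 = 8 warps = 8*32 = 256 threads
  // check_num_warps({4, 2, 1}, 4); // would throw: the tile needs 8 warps
}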
@@ -356,24 +346,19 @@ void grids::create_grids(std::vector<ir::value*> &grids,
    grids.push_back(ref.second);
}


unsigned grids::get_num_threads() {
  return num_warps_*32;
}

int grids::get_mts(ir::value *value, unsigned ax) {
int grids::mts(ir::value *value, unsigned ax) {
  return *mts_.at(value).at(ax);
}

int grids::get_nts(ir::value *value, unsigned ax) {
int grids::nts(ir::value *value, unsigned ax) {
  return *nts_.at(value).at(ax);
}

int grids::get_fpw(ir::value *value, unsigned ax) {
int grids::fpw(ir::value *value, unsigned ax) {
  return *fpw_.at(value).at(ax);
}

int grids::get_wpt(ir::value *value, unsigned ax) {
int grids::wpt(ir::value *value, unsigned ax) {
  return *wpt_.at(value).at(ax);
}
@@ -57,9 +57,9 @@ unsigned memalloc::get_num_bytes(ir::value *x) {
      num_elements *= x;
    size_t depth;
    if(params_->get_fragment(x, 0) == grids::HMMA_FRAGMENT_C)
      depth = params_->get_wpt(op, axis);
      depth = params_->wpt(op, axis);
    else
      depth = params_->get_mts(op, axis);
      depth = params_->mts(op, axis);
    return num_elements * num_bytes * depth;
  }
  unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
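A worked example of the size formula above, written as a hypothetical standalone helper; it assumes, as the surrounding code suggests, that num_elements is the element count of the result tile and depth is the parallel depth along the reduced axis:

#include <cstdio>

// mirrors: return num_elements * num_bytes * depth;
unsigned shared_bytes(unsigned num_elements, unsigned num_bytes, unsigned depth) {
  return num_elements * num_bytes * depth;
}

int main() {
  // e.g. 256 fp32 partial results kept by a depth of 4 along the reduced axis:
  std::printf("%u bytes\n", shared_bytes(256, 4, 4));   // prints 4096 bytes
}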
@@ -571,18 +571,17 @@ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned
  }
}

void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  auto order = reorder_->get_order(v);
  const auto& shapes = v->get_type()->get_tile_shapes();
  size_t dim = shapes.size();
  if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN){
  std::vector<unsigned> contiguous(dim);
  std::vector<unsigned> block_size(dim);
  std::vector<unsigned> warp_size(dim);
  std::vector<unsigned> n_warps(dim);
  for(unsigned i = 0; i < shapes.size(); i++){
    contiguous[i] = params_->get_nts(v, i);
    block_size[i] = params_->get_mts(v, i);
    contiguous[i] = params_->nts(v, i);
    block_size[i] = params_->mts(v, i);
  }
  to_warps(block_size, order, n_warps, warp_size);
  std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
@@ -604,7 +603,10 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
    axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
  }
}
  else {

void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  // auto order = reorder_->get_order(v);
  const auto& shapes = v->get_type()->get_tile_shapes();
  if(shapes.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shapes.size() >= 3;
@@ -616,13 +618,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
  Value *_16 = builder.getInt32(16);

  // fragments per warp
  unsigned fpw_0 = params_->get_fpw(v, 0);
  unsigned fpw_1 = params_->get_fpw(v, 1);
  unsigned fpw_2 = is_batched ? params_->get_fpw(v, 2) : 1;
  unsigned fpw_0 = params_->fpw(v, 0);
  unsigned fpw_1 = params_->fpw(v, 1);
  unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1;
  // warps per tile
  unsigned wpt_0 = params_->get_wpt(v, 0);
  unsigned wpt_1 = params_->get_wpt(v, 1);
  unsigned wpt_2 = is_batched ? params_->get_wpt(v, 2) : 1;
  unsigned wpt_0 = params_->wpt(v, 0);
  unsigned wpt_1 = params_->wpt(v, 1);
  unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1;
  // hmma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
@@ -708,6 +710,13 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
  if(is_batched)
    axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}


void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  if(params_->get_fragment(v, 0) == analysis::grids::STRIDED_SCAN)
    init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
  else
    init_hmma_axes(v, builder, u_thread_id, u_warp_id);
}

bool static inline has_phi_user(ir::value *v) {
@@ -717,22 +726,13 @@ bool static inline has_phi_user(ir::value *v) {
  }
  return false;
}
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
                            std::set<ir::value*> &seen, Value *sh_mem_ptr) {
  if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
    return;
  if(auto *user = dynamic_cast<ir::user*>(v))
    for(ir::value *op: user->ops()){
      create_tile(op, builder, seen, sh_mem_ptr);
    }
  LLVMContext &ctx = builder.getContext();

void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
  auto shapes = v->get_type()->get_tile_shapes();
  unsigned pad = alloc_->is_ld_padded(v);
  if(pad > 0)
    shapes[0] += pad;
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
  // create shared tile
  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v)){
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
  // shared copy
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
  // phi-node (double-buffering)
@@ -769,8 +769,9 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
    }
  }
}
  // create distributed tile
  else {

void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
  Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  std::vector<distributed_axis> axes(shapes.size());
  for(size_t d = 0; d < shapes.size(); d++){
@@ -803,8 +804,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
      T->set_value(idx, res);
    });
  }

}

void selection::create_tile(ir::value *v, IRBuilder<> &builder,
                            std::set<ir::value*> &seen, Value *sh_mem_ptr) {
  if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
    return;
  if(auto *user = dynamic_cast<ir::user*>(v))
    for(ir::value *op: user->ops())
      create_tile(op, builder, seen, sh_mem_ptr);
  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v))
    create_shared_tile(v, builder, sh_mem_ptr);
  else
    create_distributed_tile(v, builder);
}

void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
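For illustration, a minimal self-contained analogue of the traversal pattern in the create_tile shown above: operands are visited depth-first with a `seen` set, then each value is dispatched once to a "shared" or "distributed" handler. The `node` type and the printed handlers below are hypothetical stand-ins, not Triton code:

#include <iostream>
#include <set>
#include <vector>

struct node {
  const char *name;
  bool shared;                       // stand-in for buffer_info_->is_shared(v)
  std::vector<node*> ops;            // stand-in for user->ops()
};

void create_tile(node *v, std::set<node*> &seen) {
  if (!seen.insert(v).second)        // each value handled exactly once
    return;
  for (node *op : v->ops)            // operands first (post-order)
    create_tile(op, seen);
  if (v->shared)
    std::cout << "shared tile: " << v->name << "\n";
  else
    std::cout << "distributed tile: " << v->name << "\n";
}

int main() {
  node a{"a", true, {}};
  node b{"b", false, {&a}};
  node c{"c", false, {&a, &b}};
  std::set<node*> seen;
  create_tile(&c, seen);             // a, b, c each dispatched once, operands before users
}

Splitting the shared and distributed branches into separate functions, as the diff does, keeps this driver free of the allocation details of either tile kind.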
@@ -908,7 +920,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
  tgt_->add_barrier(module, builder);
  builder.CreateStore(result, write_ptr);
  // build result
  unsigned depth = params_->get_wpt(op, axis);
  unsigned depth = params_->wpt(op, axis);
  for(unsigned i = depth/2; i > 0; i >>= 1){
    // current indices
    indices_t current(write_idx.size(), builder.getInt32(0));
@@ -1075,12 +1087,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
      "{$10, $11}, "
      "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);

  unsigned fpw_0 = params_->get_fpw(dot, 0);
  unsigned fpw_1 = params_->get_fpw(dot, 1);
  unsigned fpw_0 = params_->fpw(dot, 0);
  unsigned fpw_1 = params_->fpw(dot, 1);
  unsigned wts_0 = fpw_0 * 8;
  unsigned wts_1 = fpw_1 * 8;
  unsigned wpt_0 = params_->get_wpt(dot, 0);
  unsigned wpt_1 = params_->get_wpt(dot, 1);
  unsigned wpt_0 = params_->wpt(dot, 0);
  unsigned wpt_1 = params_->wpt(dot, 1);
  unsigned stride_rep_i = wpt_0 * wts_0;
  unsigned stride_rep_j = wpt_1 * wts_1;
  unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -1457,7 +1469,7 @@ void selection::run(ir::module &src, Module &dst) {
  Metadata *md_args[] = {
    ValueAsMetadata::get(dst_fn),
    MDString::get(dst_ctx, "maxntidx"),
    ValueAsMetadata::get(dst_builder.getInt32(params_->get_num_threads()))
    ValueAsMetadata::get(dst_builder.getInt32(num_warps_*32))
  };
  dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args));
@@ -27,7 +27,7 @@ void vectorize::run(ir::module &mod) {
    }
    if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
      ir::value *x = i->get_operand(0);
      if(params_->get_nts(x, 0) == 1)
      if(params_->nts(x, 0) == 1)
        continue;
      builder.set_insert_point(i);
      ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
@@ -205,7 +205,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
  codegen::transform::dce dce;
  codegen::transform::peephole peephole;
  codegen::transform::reassociate reassociate(&alignment_info, &grids);
  codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &reorder, target.get());
  codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &reorder, target.get(), opt.num_warps);
  // run passes
  peephole.run(module);
  dce.run(module);