[GENERAL] Cleaned polymorphic structure of layouts analysis pass

This commit is contained in:
Philippe Tillet
2020-01-20 15:15:32 -05:00
parent 382ca2c745
commit 78b98fb7cf
17 changed files with 500 additions and 480 deletions

View File

@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {
generator::generator(analysis::axes *a_axes,
analysis::layout *layouts,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
target *tgt,
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
}
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = layouts_->get(ptr)->order[0];
size_t ld = layouts_->get(ptr)->get_order(0);
unsigned alignment = alignment_->get(ptr, ld);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
const auto& shapes = dot->get_type()->get_tile_shapes();
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
TA->set_vector_size(4*hmma->pack_size_0_);
TB->set_vector_size(4*hmma->pack_size_1_);
TA->set_return_mode(true);
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
auto ord_a = layouts_->get(dot->get_operand(0))->order;
auto ord_b = layouts_->get(dot->get_operand(1))->order;
auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
auto ord_b = layouts_->get(dot->get_operand(1))->get_order();
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884();
analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();
unsigned fpw_0 = layout->fpw.at(0);
unsigned fpw_1 = layout->fpw.at(1);
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = layout->wpt.at(0);
unsigned wpt_1 = layout->wpt.at(1);
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(layouts_->get(dot)->type == analysis::HMMA_884)
if(layouts_->get(dot)->to_mma884())
visit_hmma_dot(dot, TA, TB, TD, NK);
else
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
});
// reduce within blocks
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
shared_tile *stile = (shared_tile*)slayout->create(x);
unsigned depth = stile->get_shapes()[axis];
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
// pointer to temporary shared memory
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
// layouts
analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884();
analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline();
analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
// machine tiles
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
// WMMA configuration
long wmma_pt[3] = { 2, 4, 1};
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
8*in_layout->wpt[1]*in_layout->fpw[1],
long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
8*in_layout->wpt(1)*in_layout->fpw(1),
1};
// Work per thread for input layout
long in_pt[3] = { shape[0] / wmma[0],
shape[1] / wmma[1],
1 };
// Work per thread for output layout
long out_pt[3] = { shape[0] / out_layout->mts[0],
shape[1] / out_layout->mts[1],
long out_pt[3] = { shape[0] / out_layout->mts(0),
shape[1] / out_layout->mts(1),
1 };
if(rank > 2){
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
in_pt[2] = shape[2] / wmma[2];
out_pt[2] = shape[2] / out_layout->mts[2];
out_pt[2] = shape[2] / out_layout->mts(2);
}
// Orders
auto ord = out_layout->order;
auto ord = out_layout->get_order();
if(ord.size() < 3)
ord.push_back(2);
// pointer lanes
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
unsigned vector_size = 1;
ir::value *arg = cts->get_operand(0);
analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared();
analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->order;
auto in_order = in_layout->order;
analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
vector_size = in_layout->nts.at(in_order[0]);
vector_size = in_layout->nts(in_order[0]);
std::map<unsigned, Value*> packets;
for_each(arg, [&](indices_t idx){
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
}
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
}
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
void generator::visit_layout_shared(analysis::shared_layout* layout) {
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
}
void generator::visit_basic_block(ir::basic_block * block) {
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
}
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
if(shared->double_buffer) {
auto info = *shared->double_buffer;
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
if(shared->get_double_buffer()) {
auto info = *shared->get_double_buffer();
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
offset->addIncoming(next_offset, llvm_inc_block);
}
else {
unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8;
offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
@@ -1258,7 +1258,7 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
void generator::finalize_function(ir::function *fn) {
// finalize double-buffering
for(const auto& x: layouts_->get_all())
if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
finalize_shared_layout(shared);
// finalize phi
for(ir::basic_block *block: fn->blocks())