[GENERAL] Cleaned polymorphic structure of layouts analysis pass
This commit is contained in:
@@ -174,7 +174,7 @@ inline bool is_trans(ir::value *v) {
|
||||
|
||||
|
||||
generator::generator(analysis::axes *a_axes,
|
||||
analysis::layout *layouts,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
target *tgt,
|
||||
@@ -295,7 +295,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
}
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
unsigned alignment = std::max<int>(alignment_->get(ptr, ld), 1);
|
||||
|
||||
|
||||
@@ -337,7 +337,7 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = layouts_->get(ptr)->order[0];
|
||||
size_t ld = layouts_->get(ptr)->get_order(0);
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||
@@ -603,7 +603,7 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst*) {
|
||||
|
||||
void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) {
|
||||
const auto& shapes = dot->get_type()->get_tile_shapes();
|
||||
machine_layout_hmma_884_t* hmma = (machine_layout_hmma_884_t*)machine_layouts_.at(layouts_->get(dot));
|
||||
machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot));
|
||||
TA->set_vector_size(4*hmma->pack_size_0_);
|
||||
TB->set_vector_size(4*hmma->pack_size_1_);
|
||||
TA->set_return_mode(true);
|
||||
@@ -625,8 +625,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
|
||||
|
||||
auto ord_a = layouts_->get(dot->get_operand(0))->order;
|
||||
auto ord_b = layouts_->get(dot->get_operand(1))->order;
|
||||
auto ord_a = layouts_->get(dot->get_operand(0))->get_order();
|
||||
auto ord_b = layouts_->get(dot->get_operand(1))->get_order();
|
||||
|
||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||
@@ -655,14 +655,14 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
||||
"{$8, $9}, "
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
analysis::layout_hmma_884_t* layout = layouts_->get(dot)->to_hmma884();
|
||||
analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884();
|
||||
|
||||
unsigned fpw_0 = layout->fpw.at(0);
|
||||
unsigned fpw_1 = layout->fpw.at(1);
|
||||
unsigned fpw_0 = layout->fpw(0);
|
||||
unsigned fpw_1 = layout->fpw(1);
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = layout->wpt.at(0);
|
||||
unsigned wpt_1 = layout->wpt.at(1);
|
||||
unsigned wpt_0 = layout->wpt(0);
|
||||
unsigned wpt_1 = layout->wpt(1);
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||
@@ -792,7 +792,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
if(NK != 1) {
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(layouts_->get(dot)->type == analysis::HMMA_884)
|
||||
if(layouts_->get(dot)->to_mma884())
|
||||
visit_hmma_dot(dot, TA, TB, TD, NK);
|
||||
else
|
||||
visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
@@ -856,7 +856,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||
});
|
||||
|
||||
// reduce within blocks
|
||||
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
|
||||
machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
|
||||
shared_tile *stile = (shared_tile*)slayout->create(x);
|
||||
unsigned depth = stile->get_shapes()[axis];
|
||||
|
||||
@@ -926,31 +926,31 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
// pointer to temporary shared memory
|
||||
Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_);
|
||||
// layouts
|
||||
analysis::layout_hmma_884_t* in_layout = layouts_->get(op)->to_hmma884();
|
||||
analysis::layout_scanline_t* out_layout = layouts_->get(rc)->to_scanline();
|
||||
analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884();
|
||||
analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline();
|
||||
// machine tiles
|
||||
distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op));
|
||||
distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc));
|
||||
// WMMA configuration
|
||||
long wmma_pt[3] = { 2, 4, 1};
|
||||
long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0],
|
||||
8*in_layout->wpt[1]*in_layout->fpw[1],
|
||||
long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0),
|
||||
8*in_layout->wpt(1)*in_layout->fpw(1),
|
||||
1};
|
||||
// Work per thread for input layout
|
||||
long in_pt[3] = { shape[0] / wmma[0],
|
||||
shape[1] / wmma[1],
|
||||
1 };
|
||||
// Work per thread for output layout
|
||||
long out_pt[3] = { shape[0] / out_layout->mts[0],
|
||||
shape[1] / out_layout->mts[1],
|
||||
long out_pt[3] = { shape[0] / out_layout->mts(0),
|
||||
shape[1] / out_layout->mts(1),
|
||||
1 };
|
||||
if(rank > 2){
|
||||
wmma[2] = in_layout->wpt[2]*in_layout->fpw[2];
|
||||
wmma[2] = in_layout->wpt(2)*in_layout->fpw(2);
|
||||
in_pt[2] = shape[2] / wmma[2];
|
||||
out_pt[2] = shape[2] / out_layout->mts[2];
|
||||
out_pt[2] = shape[2] / out_layout->mts(2);
|
||||
}
|
||||
// Orders
|
||||
auto ord = out_layout->order;
|
||||
auto ord = out_layout->get_order();
|
||||
if(ord.size() < 3)
|
||||
ord.push_back(2);
|
||||
// pointer lanes
|
||||
@@ -1028,13 +1028,13 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
unsigned vector_size = 1;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
analysis::layout_shared_t* out_layout = layouts_->get(cts)->to_shared();
|
||||
analysis::layout_scanline_t* in_layout = layouts_->get(arg)->to_scanline();
|
||||
auto out_order = out_layout->order;
|
||||
auto in_order = in_layout->order;
|
||||
analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
|
||||
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
|
||||
auto out_order = out_layout->get_order();
|
||||
auto in_order = in_layout->get_order();
|
||||
// tiles
|
||||
if(out_order == in_order)
|
||||
vector_size = in_layout->nts.at(in_order[0]);
|
||||
vector_size = in_layout->nts(in_order[0]);
|
||||
|
||||
std::map<unsigned, Value*> packets;
|
||||
for_each(arg, [&](indices_t idx){
|
||||
@@ -1180,17 +1180,17 @@ void generator::visit_function(ir::function* fn) {
|
||||
|
||||
|
||||
|
||||
void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) {
|
||||
machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) {
|
||||
machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout);
|
||||
void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
|
||||
machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout);
|
||||
}
|
||||
|
||||
void generator::visit_layout_shared(analysis::layout_shared_t* layout) {
|
||||
void generator::visit_layout_shared(analysis::shared_layout* layout) {
|
||||
|
||||
machine_layouts_[layout] = new machine_layout_shared_t(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_);
|
||||
}
|
||||
|
||||
void generator::visit_basic_block(ir::basic_block * block) {
|
||||
@@ -1230,9 +1230,9 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) {
|
||||
}
|
||||
|
||||
|
||||
void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
||||
if(shared->double_buffer) {
|
||||
auto info = *shared->double_buffer;
|
||||
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
|
||||
if(shared->get_double_buffer()) {
|
||||
auto info = *shared->get_double_buffer();
|
||||
ir::phi_node *phi = info.phi;
|
||||
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
|
||||
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
|
||||
@@ -1247,8 +1247,8 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
||||
offset->addIncoming(next_offset, llvm_inc_block);
|
||||
}
|
||||
else {
|
||||
unsigned num_bytes = shared->ty->get_primitive_size_in_bits() / 8;
|
||||
offset->addIncoming(builder_->getInt32(shared->size / (2*num_bytes)), llvm_inc_block);
|
||||
unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8;
|
||||
offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block);
|
||||
}
|
||||
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
|
||||
}
|
||||
@@ -1258,7 +1258,7 @@ void generator::finalize_shared_layout(analysis::layout_shared_t *shared) {
|
||||
void generator::finalize_function(ir::function *fn) {
|
||||
// finalize double-buffering
|
||||
for(const auto& x: layouts_->get_all())
|
||||
if(auto *shared = dynamic_cast<analysis::layout_shared_t*>(x.second))
|
||||
if(auto *shared = dynamic_cast<analysis::shared_layout*>(x.second))
|
||||
finalize_shared_layout(shared);
|
||||
// finalize phi
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
|
Reference in New Issue
Block a user