more progress

This commit is contained in:
Philippe Tillet
2019-09-08 21:36:54 -04:00
parent 3d78810d5e
commit 3daef1726d
3 changed files with 18 additions and 9 deletions

View File

@@ -17,6 +17,11 @@ namespace ir{
}
namespace codegen{
// Forward declaration of the reorder pass. The grids class below only refers
// to it through a pointer (constructor argument and the reorder_ member), so
// a declaration is enough at this point.
namespace transform{
class reorder;
}
namespace analysis{
class grids {
@@ -36,12 +41,12 @@ private:
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
std::map<std::pair<unsigned, std::vector<unsigned> >, triton::ir::value *> &references,
ir::function *fn);
public:
grids(size_t num_warps);
grids(size_t num_warps, transform::reorder* reorder);
// Returns the metaparameter stored in params_ for `value` under `key`.
// NOTE(review): both lookup levels use operator[]; if params_ is a
// std::map-like container, a missing value or key silently inserts a
// default (null) entry instead of failing -- confirm this is intended.
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
unsigned get_param_group(ir::value *value, unsigned ax);
// Returns the fragment_t recorded in fragments_ for the pair (value, ax).
// NOTE(review): the lookup uses .at(), so a missing (value, ax) entry
// presumably throws (std::out_of_range if fragments_ is a standard map)
// rather than inserting a default -- confirm against the member's type.
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
@@ -60,6 +65,8 @@ private:
std::vector<ir::value*> grids_;
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
size_t num_warps_;
transform::reorder* reorder_;
};

View File

@@ -1,5 +1,6 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -15,7 +16,7 @@ namespace triton{
namespace codegen{
namespace analysis{
grids::grids(size_t num_warps): num_warps_(num_warps)
// Stores the warp count and the reorder pass pointer; no other work is done
// at construction time.
// NOTE(review): reorder_ is later dereferenced without a null check
// (reorder_->get_order(v) in create_grids), so callers presumably must pass
// a valid pointer that stays alive while this analysis is used -- confirm.
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
{ }
bool is_hmma(ir::value *v){
@@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
}
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
@@ -225,7 +225,7 @@ void grids::run(ir::module &mod) {
}
for(ir::function *fn: mod.get_function_list()){
std::map<unsigned, ir::value*> references;
std::map<std::pair<unsigned, std::vector<unsigned>>, ir::value*> references;
create_grids(grids_, references, fn);
}
@@ -317,7 +317,8 @@ void grids::run(ir::module &mod) {
void grids::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
std::map<std::pair<unsigned,
std::vector<unsigned>>, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
@@ -331,6 +332,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
std::set<ir::value*> seen;
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
auto order = reorder_->get_order(v);
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
@@ -344,7 +346,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
if(shapes[d] == 1)
continue;
unsigned x = get_param_group(v, d);
ir::value *&r = references[x];
ir::value *&r = references[{x, order}];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
}

View File

@@ -194,12 +194,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
std::unique_ptr<codegen::target> target = context->device()->make_target();
// create passes
codegen::analysis::grids grids(opt.num_warps);
codegen::analysis::meminfo shmem_info;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::align alignment_info;
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
codegen::analysis::grids grids(opt.num_warps, &reorder);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce;