more progress
This commit is contained in:
@@ -17,6 +17,11 @@ namespace ir{
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
class reorder;
|
||||
}
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class grids {
|
||||
@@ -36,12 +41,12 @@ private:
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
std::map<std::pair<unsigned, std::vector<unsigned> >, triton::ir::value *> &references,
|
||||
ir::function *fn);
|
||||
|
||||
|
||||
public:
|
||||
grids(size_t num_warps);
|
||||
grids(size_t num_warps, transform::reorder* reorder);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
@@ -60,6 +65,8 @@ private:
|
||||
std::vector<ir::value*> grids_;
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
size_t num_warps_;
|
||||
transform::reorder* reorder_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
@@ -15,7 +16,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
grids::grids(size_t num_warps): num_warps_(num_warps)
|
||||
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
}
|
||||
|
||||
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
@@ -225,7 +225,7 @@ void grids::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<unsigned, ir::value*> references;
|
||||
std::map<std::pair<unsigned, std::vector<unsigned>>, ir::value*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
@@ -317,7 +317,8 @@ void grids::run(ir::module &mod) {
|
||||
|
||||
|
||||
void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
std::map<std::pair<unsigned,
|
||||
std::vector<unsigned>>, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
@@ -331,6 +332,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
std::set<ir::value*> seen;
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
auto order = reorder_->get_order(v);
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
@@ -344,7 +346,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
if(shapes[d] == 1)
|
||||
continue;
|
||||
unsigned x = get_param_group(v, d);
|
||||
ir::value *&r = references[x];
|
||||
ir::value *&r = references[{x, order}];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
}
|
||||
|
@@ -194,12 +194,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
|
||||
// create passes
|
||||
codegen::analysis::grids grids(opt.num_warps);
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
|
||||
codegen::analysis::grids grids(opt.num_warps, &reorder);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&grids);
|
||||
codegen::transform::dce dce;
|
||||
|
Reference in New Issue
Block a user