more progress
This commit is contained in:
@@ -17,6 +17,11 @@ namespace ir{
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
|
namespace transform{
|
||||||
|
class reorder;
|
||||||
|
}
|
||||||
|
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
|
|
||||||
class grids {
|
class grids {
|
||||||
@@ -36,12 +41,12 @@ private:
|
|||||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||||
void create_grids(std::vector<ir::value*> &grids,
|
void create_grids(std::vector<ir::value*> &grids,
|
||||||
std::map<unsigned, ir::value*> &references,
|
std::map<std::pair<unsigned, std::vector<unsigned> >, triton::ir::value *> &references,
|
||||||
ir::function *fn);
|
ir::function *fn);
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
grids(size_t num_warps);
|
grids(size_t num_warps, transform::reorder* reorder);
|
||||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||||
@@ -60,6 +65,8 @@ private:
|
|||||||
std::vector<ir::value*> grids_;
|
std::vector<ir::value*> grids_;
|
||||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||||
size_t num_warps_;
|
size_t num_warps_;
|
||||||
|
transform::reorder* reorder_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include "triton/codegen/transform/reorder.h"
|
||||||
#include "triton/codegen/analysis/grid.h"
|
#include "triton/codegen/analysis/grid.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
@@ -15,7 +16,7 @@ namespace triton{
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
|
|
||||||
grids::grids(size_t num_warps): num_warps_(num_warps)
|
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
bool is_hmma(ir::value *v){
|
bool is_hmma(ir::value *v){
|
||||||
@@ -157,7 +158,6 @@ grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
|
|||||||
}
|
}
|
||||||
|
|
||||||
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
void grids::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||||
std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
|
|
||||||
groups_[x.first].insert({x.second, group_id});
|
groups_[x.first].insert({x.second, group_id});
|
||||||
if(nodes.find(x) != nodes.end()){
|
if(nodes.find(x) != nodes.end()){
|
||||||
nodes.erase(x);
|
nodes.erase(x);
|
||||||
@@ -225,7 +225,7 @@ void grids::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for(ir::function *fn: mod.get_function_list()){
|
for(ir::function *fn: mod.get_function_list()){
|
||||||
std::map<unsigned, ir::value*> references;
|
std::map<std::pair<unsigned, std::vector<unsigned>>, ir::value*> references;
|
||||||
create_grids(grids_, references, fn);
|
create_grids(grids_, references, fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,7 +317,8 @@ void grids::run(ir::module &mod) {
|
|||||||
|
|
||||||
|
|
||||||
void grids::create_grids(std::vector<ir::value*> &grids,
|
void grids::create_grids(std::vector<ir::value*> &grids,
|
||||||
std::map<unsigned, ir::value*> &references,
|
std::map<std::pair<unsigned,
|
||||||
|
std::vector<unsigned>>, ir::value*> &references,
|
||||||
ir::function *fn) {
|
ir::function *fn) {
|
||||||
// get number of dimensions greater than 1
|
// get number of dimensions greater than 1
|
||||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||||
@@ -331,6 +332,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
|||||||
std::set<ir::value*> seen;
|
std::set<ir::value*> seen;
|
||||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||||
{
|
{
|
||||||
|
auto order = reorder_->get_order(v);
|
||||||
// skip
|
// skip
|
||||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||||
return;
|
return;
|
||||||
@@ -344,7 +346,7 @@ void grids::create_grids(std::vector<ir::value*> &grids,
|
|||||||
if(shapes[d] == 1)
|
if(shapes[d] == 1)
|
||||||
continue;
|
continue;
|
||||||
unsigned x = get_param_group(v, d);
|
unsigned x = get_param_group(v, d);
|
||||||
ir::value *&r = references[x];
|
ir::value *&r = references[{x, order}];
|
||||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||||
r = v;
|
r = v;
|
||||||
}
|
}
|
||||||
|
@@ -194,12 +194,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||||
|
|
||||||
// create passes
|
// create passes
|
||||||
codegen::analysis::grids grids(opt.num_warps);
|
|
||||||
codegen::analysis::meminfo shmem_info;
|
codegen::analysis::meminfo shmem_info;
|
||||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
|
||||||
codegen::analysis::align alignment_info;
|
codegen::analysis::align alignment_info;
|
||||||
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
|
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
|
||||||
|
codegen::analysis::grids grids(opt.num_warps, &reorder);
|
||||||
|
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||||
codegen::transform::vectorize vectorize(&grids);
|
codegen::transform::vectorize vectorize(&grids);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
|
Reference in New Issue
Block a user