154 lines
3.7 KiB
C++
154 lines
3.7 KiB
C++
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
|
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
|
|
|
#include <map>
|
|
#include <set>
|
|
#include <vector>
|
|
#include <memory>
|
|
#include "triton/tools/graph.h"
|
|
|
|
namespace triton{
|
|
|
|
// Forward declarations of the IR types that this header only refers to
// through pointers or references, so their full definitions are not needed.
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
|
|
|
|
namespace codegen{
|
|
namespace analysis{
|
|
|
|
// Forward declarations of analysis::axes and analysis::align; this header
// only handles them through pointers.
class axes;
class align;
|
|
|
|
// Tag stored in layout_t::type. Each enumerator corresponds by name to one of
// the concrete layout structs declared in this header. The enum is left
// unscoped and the enumerators keep their implicit values (0, 1, 2).
enum layout_type_t {
  HMMA_884,  // see layout_hmma_884_t
  SCANLINE,  // see layout_scanline_t
  SHARED     // see layout_shared_t
};
|
|
|
|
struct double_buffer_info_t {
|
|
ir::value* first;
|
|
ir::value* latch;
|
|
ir::phi_node* phi;
|
|
};
|
|
|
|
// Forward declarations so that the visitor interface and the concrete layouts
// can refer to each other. The layouts are declared with `struct` to match the
// class-key used by their definitions further down in this header.
class layout_visitor;
struct layout_hmma_884_t;
struct layout_scanline_t;
struct layout_shared_t;

// Abstract visitor with one pure-virtual hook per concrete layout struct.
// Each layout's accept() calls the hook that matches its own type.
class layout_visitor {
public:
  // Virtual so that an implementation can be safely destroyed through a
  // layout_visitor pointer.
  virtual ~layout_visitor() = default;

  // Called by layout_hmma_884_t::accept with the visited layout.
  virtual void visit_layout_hmma_884(layout_hmma_884_t*) = 0;
  // Called by layout_scanline_t::accept with the visited layout.
  virtual void visit_layout_scanline(layout_scanline_t*) = 0;
  // Called by layout_shared_t::accept with the visited layout.
  virtual void visit_layout_shared(layout_shared_t*) = 0;
};
|
|
|
|
struct layout_t {
|
|
layout_t(layout_type_t _type,
|
|
const std::vector<int>& _axes,
|
|
const std::vector<unsigned> &_shapes,
|
|
const std::vector<ir::value *> &_values,
|
|
size_t _id,
|
|
analysis::align* align);
|
|
|
|
virtual void accept(layout_visitor* vst) = 0;
|
|
|
|
layout_type_t type;
|
|
std::vector<int> axes;
|
|
std::vector<unsigned> shapes;
|
|
std::vector<ir::value*> values;
|
|
std::vector<int> order;
|
|
size_t id;
|
|
size_t size;
|
|
std::shared_ptr<double_buffer_info_t> double_buffer;
|
|
ir::type *ty;
|
|
size_t pad;
|
|
std::vector<int> mts;
|
|
std::vector<int> nts;
|
|
std::vector<int> fpw;
|
|
std::vector<int> wpt;
|
|
};
|
|
|
|
struct layout_hmma_884_t: public layout_t {
|
|
layout_hmma_884_t(size_t num_warps,
|
|
const std::vector<int>& _axes,
|
|
const std::vector<unsigned>& _shapes,
|
|
const std::vector<ir::value *> &_values,
|
|
size_t _id,
|
|
analysis::align* align);
|
|
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
|
|
};
|
|
|
|
struct layout_scanline_t: public layout_t {
|
|
layout_scanline_t(size_t num_warps,
|
|
const std::vector<int>& _axes,
|
|
const std::vector<unsigned>& _shapes,
|
|
const std::vector<ir::value *> &values,
|
|
size_t _id,
|
|
analysis::align* align);
|
|
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
|
};
|
|
|
|
struct layout_shared_t: public layout_t {
|
|
layout_shared_t(const layout_t *arg,
|
|
const std::vector<int>& _axes,
|
|
const std::vector<unsigned>& _shapes,
|
|
const std::vector<ir::value *> &values,
|
|
ir::type *ty,
|
|
size_t _id,
|
|
analysis::align* align);
|
|
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
|
};
|
|
|
|
|
|
|
|
class layout {
|
|
typedef ir::value* node_t;
|
|
typedef std::map <node_t, std::set<node_t>> graph_t;
|
|
|
|
private:
|
|
// graph creation
|
|
void connect(ir::value *x, ir::value *y);
|
|
void make_graph(ir::instruction *i);
|
|
|
|
void init_hmma_tile(layout_t& layout);
|
|
void init_scanline_tile(layout_t &layout);
|
|
|
|
void create(size_t id, const std::vector<ir::value*>& values);
|
|
|
|
public:
|
|
// constructor
|
|
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
|
|
|
// accessors
|
|
unsigned layout_of(ir::value *value) const;
|
|
const std::vector<ir::value*>& values_of(unsigned id) const;
|
|
size_t num_layouts() const;
|
|
const layout_t* get(ir::value *v) const;
|
|
std::map<size_t, layout_t*> &get_all();
|
|
|
|
// execution
|
|
void run(ir::module &mod);
|
|
|
|
private:
|
|
analysis::axes* axes_;
|
|
analysis::align* align_;
|
|
size_t num_warps_;
|
|
tools::graph<ir::value*> graph_;
|
|
std::map<ir::value*, size_t> groups_;
|
|
std::map<size_t, std::vector<ir::value*>> values_;
|
|
std::map<size_t, layout_t*> layouts_;
|
|
};
|
|
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|