[codegen][analysis] cleaned-up tiling formalism
This commit is contained in:
49
include/triton/codegen/analysis/axes.h
Normal file
49
include/triton/codegen/analysis/axes.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value *value, unsigned ax);
|
||||
bool has(ir::value *value, unsigned ax);
|
||||
|
||||
private:
|
||||
// constraints graph
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
// parameter groups
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,89 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_TUNE_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_TUNE_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
class function;
|
||||
class metaparameter;
|
||||
class constant_int;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class grids {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
typedef std::shared_ptr<int> param_ptr_t;
|
||||
typedef std::map<ir::value*, std::map<int, param_ptr_t>> param_map_t;
|
||||
|
||||
public:
|
||||
enum fragment_t{
|
||||
STRIDED_SCAN,
|
||||
HMMA_FRAGMENT_C
|
||||
};
|
||||
|
||||
private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<param_ptr_t>& params, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, triton::ir::value *> &references,
|
||||
ir::function *fn);
|
||||
|
||||
|
||||
public:
|
||||
grids(size_t num_warps, transform::coalesce* coalesce);
|
||||
void run(ir::module &mod);
|
||||
const std::vector<ir::value*> get() const { return grids_; }
|
||||
fragment_t fragment_of(ir::value *value, unsigned ax);
|
||||
unsigned group_of(ir::value *value, unsigned ax);
|
||||
int mts(ir::value *value, unsigned ax);
|
||||
int nts(ir::value *value, unsigned ax);
|
||||
int fpw(ir::value *value, unsigned ax);
|
||||
int wpt(ir::value *value, unsigned ax);
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
|
||||
private:
|
||||
|
||||
transform::coalesce* coalesce_;
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// grids
|
||||
std::vector<ir::value*> grids_;
|
||||
// grid parameters
|
||||
param_map_t fpw_;
|
||||
param_map_t wpt_;
|
||||
param_map_t mts_;
|
||||
param_map_t nts_;
|
||||
// constraints graph
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
// fragments
|
||||
std::map<node_t, fragment_t> fragments_;
|
||||
// parameter groups
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
57
include/triton/codegen/analysis/layout.h
Normal file
57
include/triton/codegen/analysis/layout.h
Normal file
@@ -0,0 +1,57 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
|
||||
class layout {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
// connected components
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
|
||||
// list the axes of the given value
|
||||
std::set<int> axes_of(ir::value *value);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layout(analysis::axes *axes);
|
||||
// run the passes
|
||||
void run(ir::module &mod);
|
||||
// get the layout ID of the given value
|
||||
unsigned id(ir::value *value) const;
|
||||
// get the values associates with the given ID
|
||||
const std::vector<ir::value*>& values(unsigned id) const;
|
||||
// get number of groups
|
||||
size_t get_num_groups() const;
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
std::map<ir::value*, unsigned> groups_;
|
||||
std::map<unsigned, std::vector<ir::value*>> values_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -15,15 +15,15 @@ namespace ir{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class grids;
|
||||
class tiles;
|
||||
|
||||
class liveness;
|
||||
class meminfo;
|
||||
|
||||
class memalloc {
|
||||
public:
|
||||
memalloc(liveness *live, meminfo *buffer_info, grids *params)
|
||||
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
|
||||
memalloc(liveness *live, meminfo *buffer_info, tiles *params)
|
||||
: liveness_(live), buffer_info_(buffer_info), tiles_(params){ }
|
||||
// utilities
|
||||
unsigned num_bytes(ir::value *x);
|
||||
unsigned is_ld_padded(ir::value* x);
|
||||
@@ -40,7 +40,7 @@ private:
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
meminfo *buffer_info_;
|
||||
grids *params_;
|
||||
tiles *tiles_;
|
||||
};
|
||||
|
||||
}
|
||||
|
68
include/triton/codegen/analysis/tiles.h
Normal file
68
include/triton/codegen/analysis/tiles.h
Normal file
@@ -0,0 +1,68 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_TILES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_TILES_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
class function;
|
||||
class metaparameter;
|
||||
class constant_int;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class layout;
|
||||
|
||||
class tiles {
|
||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||
private:
|
||||
void init_hmma_tile(ir::value *i);
|
||||
void init_scanline_tile(ir::value *i);
|
||||
|
||||
public:
|
||||
tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout);
|
||||
void run(ir::module &mod);
|
||||
bool hmma(ir::value *value);
|
||||
int mts(ir::value *value, unsigned ax);
|
||||
int nts(ir::value *value, unsigned ax);
|
||||
int fpw(ir::value *value, unsigned ax);
|
||||
int wpt(ir::value *value, unsigned ax);
|
||||
const std::map<int, ir::value*>& largest();
|
||||
|
||||
private:
|
||||
// dependencies
|
||||
analysis::layout* layout_;
|
||||
analysis::axes* axes_;
|
||||
transform::coalesce* coalesce_;
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// tile properties
|
||||
std::map<int, bool> hmma_;
|
||||
std::map<int, ir::value*> largest_;
|
||||
std::map<int, int> fpw_;
|
||||
std::map<int, int> wpt_;
|
||||
std::map<int, int> mts_;
|
||||
std::map<int, int> nts_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -43,10 +43,12 @@ namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class grids;
|
||||
class tiles;
|
||||
class align;
|
||||
class memalloc;
|
||||
class meminfo;
|
||||
class axes;
|
||||
class layout;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
@@ -199,8 +201,12 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt, unsigned num_warps)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
|
||||
selection(analysis::memalloc *alloc, analysis::tiles *tiles, analysis::meminfo *buffer_info,
|
||||
analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts,
|
||||
transform::coalesce* reorder, target *tgt, unsigned num_warps)
|
||||
: alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
@@ -208,7 +214,9 @@ private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
analysis::memalloc *alloc_;
|
||||
analysis::grids *params_;
|
||||
analysis::tiles *tiles_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::meminfo *buffer_info_;
|
||||
analysis::align *alignment_;
|
||||
transform::coalesce *reorder_;
|
||||
|
@@ -19,7 +19,7 @@ class getelementptr_inst;
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class grids;
|
||||
class tiles;
|
||||
class align;
|
||||
}
|
||||
|
||||
@@ -37,11 +37,10 @@ private:
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(analysis::align* align, analysis::grids *params);
|
||||
reassociate(analysis::align* align);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
analysis::grids* params_;
|
||||
analysis::align* align_;
|
||||
};
|
||||
|
||||
|
@@ -10,18 +10,18 @@ namespace ir {
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class grids;
|
||||
class tiles;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class vectorize {
|
||||
public:
|
||||
vectorize(analysis::grids *params): params_(params){}
|
||||
vectorize(analysis::tiles *params): params_(params){}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::grids *params_;
|
||||
analysis::tiles *params_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -63,6 +63,7 @@ public:
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
const tile_shapes_t& get_tile_shapes() const;
|
||||
const size_t get_tile_rank() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
|
@@ -11,7 +11,7 @@
|
||||
// codegen
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
@@ -45,7 +45,7 @@ class translation_unit;
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class grids;
|
||||
class tiles;
|
||||
}
|
||||
}
|
||||
|
||||
|
166
lib/codegen/analysis/axes.cc
Normal file
166
lib/codegen/analysis/axes.cc
Normal file
@@ -0,0 +1,166 @@
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
axes::axes() {}
|
||||
|
||||
void axes::add_constraint(node_t x, node_t y) {
|
||||
size_t shape_x = 1;
|
||||
size_t shape_y = 1;
|
||||
if(x.first->get_type()->is_tile_ty())
|
||||
shape_x = x.first->get_type()->get_tile_shapes()[x.second];
|
||||
if(y.first->get_type()->is_tile_ty())
|
||||
shape_y = y.first->get_type()->get_tile_shapes()[y.second];
|
||||
if(shape_x == 1 && shape_y == 1)
|
||||
return;
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
}
|
||||
|
||||
void axes::init_c_graph(ir::instruction *v) {
|
||||
// Reference shape
|
||||
ir::type::tile_shapes_t shapes;
|
||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
|
||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||
else if(dynamic_cast<ir::downcast_inst*>(v))
|
||||
return;
|
||||
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return;
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
|
||||
unsigned axis = reduce->get_axis();
|
||||
ir::value *arg = reduce->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
||||
if(i == axis)
|
||||
continue;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
return;
|
||||
}
|
||||
else
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
// Reshape
|
||||
if(dynamic_cast<ir::reshape_inst*>(v)) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1){
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed &&
|
||||
shapes[i] == op_shapes[current])
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else{
|
||||
is_skewed = true;
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(v)){
|
||||
return;
|
||||
}
|
||||
// Trans
|
||||
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto perm = x->get_perm();
|
||||
for(unsigned i = 0; i < perm.size(); i++)
|
||||
add_constraint({v, perm[i]->get_value()}, {op, i});
|
||||
}
|
||||
// Broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
ir::type *op_ty = op->get_type();
|
||||
const auto& op_shapes = op_ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(op_shapes[i] == shapes[i] && v != op)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
// Matrix multiplication
|
||||
else if(dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *A = v->get_operand(0);
|
||||
ir::value *B = v->get_operand(1);
|
||||
ir::value *D = v->get_operand(2);
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v, i}, {D, i});
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned axes::get(ir::value *value, unsigned ax) {
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool axes::has(ir::value *value, unsigned ax) {
|
||||
auto it = groups_.find(value);
|
||||
if(it == groups_.end())
|
||||
return false;
|
||||
auto iit = it->second.find(ax);
|
||||
if(iit == it->second.end())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void axes::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
// Create graph
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_graph(i);
|
||||
}
|
||||
// Axes
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty())
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@@ -1,367 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), coalesce_(reorder)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *a = x->get_operand(0);
|
||||
ir::type *a_ty = a->get_type();
|
||||
ir::value *b = x->get_operand(1);
|
||||
ir::type *b_ty = b->get_type();
|
||||
// inputs have to be FP16
|
||||
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
||||
// reduction has to be multiple of 4: TODO
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void grids::add_constraint(node_t x, node_t y) {
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
}
|
||||
|
||||
void grids::init_c_phi(ir::instruction *v) {
|
||||
// Phi Nodes: all the incoming value share the result layout
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(v))
|
||||
for(ir::value *op: phi->ops())
|
||||
for(unsigned k = 0; k < phi->get_type()->get_tile_shapes().size(); k++)
|
||||
if(dependencies_.find({op, k}) != dependencies_.end()
|
||||
|| dependencies_.find({phi, k}) != dependencies_.end()){
|
||||
add_constraint({phi, k}, {op, k});
|
||||
}
|
||||
}
|
||||
|
||||
void grids::init_c_graph(ir::instruction *v) {
|
||||
// Reference shape
|
||||
ir::type::tile_shapes_t shapes;
|
||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
|
||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||
else if(dynamic_cast<ir::downcast_inst*>(v))
|
||||
return;
|
||||
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return;
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
|
||||
unsigned axis = reduce->get_axis();
|
||||
ir::value *arg = reduce->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
||||
if(i == axis)
|
||||
continue;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
return;
|
||||
}
|
||||
else
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
// Reshape
|
||||
if(dynamic_cast<ir::reshape_inst*>(v)) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1){
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed &&
|
||||
shapes[i] == op_shapes[current])
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else{
|
||||
is_skewed = true;
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(v)){
|
||||
return;
|
||||
}
|
||||
// Trans
|
||||
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto perm = x->get_perm();
|
||||
for(unsigned i = 0; i < perm.size(); i++)
|
||||
add_constraint({v, perm[i]->get_value()}, {op, i});
|
||||
}
|
||||
// Broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
ir::type *op_ty = op->get_type();
|
||||
const auto& op_shapes = op_ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(op_shapes[i] == shapes[i] && v != op)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
// Matrix multiplication
|
||||
else if(dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *A = v->get_operand(0);
|
||||
ir::value *B = v->get_operand(1);
|
||||
ir::value *D = v->get_operand(2);
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v, i}, {D, i});
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
std::list<node_t> work;
|
||||
std::set<node_t> seen;
|
||||
work.push_back(x);
|
||||
while(!work.empty()){
|
||||
node_t current = work.back();
|
||||
if(is_hmma(current.first))
|
||||
return HMMA_FRAGMENT_C;
|
||||
work.pop_back();
|
||||
seen.insert(current);
|
||||
for(node_t y: graph[current]){
|
||||
if(seen.find(y) == seen.end())
|
||||
work.push_back(y);
|
||||
}
|
||||
}
|
||||
return STRIDED_SCAN;
|
||||
}
|
||||
|
||||
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps,
|
||||
std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(unsigned i = 0; i < ptr_vec.size(); i++)
|
||||
(*maps[i])[x.first][x.second] = ptr_vec[i];
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, ptr_vec, maps, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned grids::group_of(ir::value *value, unsigned ax) {
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
|
||||
grids::fragment_t grids::fragment_of(ir::value *value, unsigned ax) {
|
||||
return fragments_.at({value, ax});
|
||||
}
|
||||
|
||||
|
||||
//TODO: This shouldn't exist!
|
||||
void grids::copy(ir::value *dst, ir::value *src) {
|
||||
mts_[dst] = mts_[src];
|
||||
nts_[dst] = nts_[src];
|
||||
fpw_[dst] = fpw_[src];
|
||||
wpt_[dst] = wpt_[src];
|
||||
groups_[dst] = groups_[src];
|
||||
fragments_[{dst, 0}] = fragments_[{src, 0}];
|
||||
}
|
||||
|
||||
|
||||
void grids::run(ir::module &mod) {
|
||||
// Create tiling parameters
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_graph(i);
|
||||
// Build phi constraints
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_phi(i);
|
||||
// Layout parameters
|
||||
unsigned group_id = 0;
|
||||
for(auto x: nodes_)
|
||||
fragments_[x] = get_fragmentation_type(x, dependencies_);
|
||||
while(!nodes_.empty()) {
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
param_ptr_t nts(new int(-1));
|
||||
param_ptr_t mts(new int(-1));
|
||||
connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
else {
|
||||
param_ptr_t fpw(new int(-1));
|
||||
param_ptr_t wpt(new int(-1));
|
||||
connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::map<unsigned, ir::value*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
|
||||
|
||||
unsigned num_threads = num_warps_*32;
|
||||
auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); };
|
||||
|
||||
for(ir::value *i: grids_){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
auto order = coalesce_->get_order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned size = i->get_type()->get_tile_num_elements();
|
||||
/* HMMA parameters*/
|
||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||
unsigned shape_0 = shapes[order[0]];
|
||||
unsigned shape_1 = shapes[order[1]];
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> fpw = {1, 1, 1};
|
||||
std::vector<unsigned> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
*fpw_[i][d] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
std::vector<unsigned> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
*wpt_[i][d] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= *wpt_[i][d];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
/* Scan-line */
|
||||
else{
|
||||
unsigned ld = order[0];
|
||||
unsigned current = num_threads;
|
||||
*nts_[i][ld] = clamp(size / num_threads, 1, 4);
|
||||
*mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]);
|
||||
current = current / *mts_[i][ld];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = order[d];
|
||||
*nts_[i][ld] = 1;
|
||||
*mts_[i][ld] = clamp(current, 1, shapes[ld]);
|
||||
current = current / *mts_[i][ld];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= *mts_[i][d];
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void grids::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(auto shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape > 1)? shape : 0;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
// bind references
|
||||
std::set<ir::value*> seen;
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
// recurse
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value *op: user->ops())
|
||||
bind_references(op);
|
||||
// bind
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] == 1)
|
||||
continue;
|
||||
unsigned x = group_of(v, d);
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
}
|
||||
};
|
||||
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
bind_references(i);
|
||||
|
||||
// create grid
|
||||
for(auto &ref: references)
|
||||
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
|
||||
grids.push_back(ref.second);
|
||||
}
|
||||
|
||||
int grids::mts(ir::value *value, unsigned ax) {
|
||||
return *mts_.at(value).at(ax);
|
||||
}
|
||||
|
||||
int grids::nts(ir::value *value, unsigned ax) {
|
||||
return *nts_.at(value).at(ax);
|
||||
}
|
||||
|
||||
int grids::fpw(ir::value *value, unsigned ax) {
|
||||
return *fpw_.at(value).at(ax);
|
||||
}
|
||||
|
||||
int grids::wpt(ir::value *value, unsigned ax) {
|
||||
return *wpt_.at(value).at(ax);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
96
lib/codegen/analysis/layout.cc
Normal file
96
lib/codegen/analysis/layout.cc
Normal file
@@ -0,0 +1,96 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
// axes
|
||||
std::set<int> layout::axes_of(ir::value *value) {
|
||||
auto ty = value->get_type();
|
||||
// rank of value
|
||||
size_t rank = 0;
|
||||
if(ty->is_tile_ty())
|
||||
rank = ty->get_tile_rank();
|
||||
// create result
|
||||
std::set<int> result;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
if(axes_->has(value, d))
|
||||
result.insert(axes_->get(value, d));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// connected components
|
||||
void layout::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x] = group_id;
|
||||
values_[group_id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
// constructor
|
||||
layout::layout(analysis::axes *axes)
|
||||
: axes_(axes) { }
|
||||
|
||||
// get group id
|
||||
unsigned layout::id(ir::value *value) const
|
||||
{ return groups_.at(value); }
|
||||
|
||||
// get values
|
||||
const std::vector<ir::value*>& layout::values(unsigned id) const
|
||||
{ return values_.at(id); }
|
||||
|
||||
// get number of groups
|
||||
size_t layout::get_num_groups() const
|
||||
{ return values_.size(); }
|
||||
|
||||
// run
|
||||
void layout::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
values_.clear();
|
||||
// Create graph
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()) {
|
||||
// skip scalars
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
// add an edge between i and the operands that share an axis
|
||||
std::set<int> i_axes = axes_of(i);
|
||||
nodes_.insert(i);
|
||||
for(ir::value* op: i->ops()){
|
||||
if(!op->get_type()->is_tile_ty())
|
||||
continue;
|
||||
nodes_.insert(op);
|
||||
std::set<int> op_axes = axes_of(op);
|
||||
std::set<int> common;
|
||||
std::set_intersection(i_axes.begin(), i_axes.end(),
|
||||
op_axes.begin(), op_axes.end(),
|
||||
std::inserter(common, common.begin()));
|
||||
if(!common.empty() || !op->get_type()->is_tile_ty()){
|
||||
dependencies_[i].insert(op);
|
||||
dependencies_[op].insert(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Grids
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()){
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -2,7 +2,7 @@
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
@@ -20,7 +20,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) {
|
||||
}
|
||||
for(ir::user* user: x->get_users())
|
||||
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
|
||||
bool is_hmma = params_->fragment_of(user, 0) == grids::HMMA_FRAGMENT_C;
|
||||
bool is_hmma = tiles_->hmma(user);
|
||||
bool is_op_0 = x == dot->get_operand(0);
|
||||
bool is_op_1 = x == dot->get_operand(1);
|
||||
if(is_hmma && is_op_0){
|
||||
@@ -56,10 +56,10 @@ unsigned memalloc::num_bytes(ir::value *x) {
|
||||
for(auto x: shapes)
|
||||
num_elements *= x;
|
||||
size_t depth;
|
||||
if(params_->fragment_of(x, 0) == grids::HMMA_FRAGMENT_C)
|
||||
depth = params_->wpt(op, axis);
|
||||
if(tiles_->hmma(x))
|
||||
depth = tiles_->wpt(op, axis);
|
||||
else
|
||||
depth = params_->mts(op, axis);
|
||||
depth = tiles_->mts(op, axis);
|
||||
return num_elements * num_bytes * depth;
|
||||
}
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
|
176
lib/codegen/analysis/tiles.cc
Normal file
176
lib/codegen/analysis/tiles.cc
Normal file
@@ -0,0 +1,176 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout):
|
||||
num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *a = x->get_operand(0);
|
||||
ir::type *a_ty = a->get_type();
|
||||
ir::value *b = x->get_operand(1);
|
||||
ir::type *b_ty = b->get_type();
|
||||
result = a_ty->get_scalar_ty()->is_half_ty() &&
|
||||
b_ty->get_scalar_ty()->is_half_ty();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool tiles::hmma(ir::value *value) {
|
||||
return hmma_.at(layout_->id(value));
|
||||
}
|
||||
|
||||
int tiles::mts(ir::value *value, unsigned ax) {
|
||||
return mts_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::nts(ir::value *value, unsigned ax) {
|
||||
return nts_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::fpw(ir::value *value, unsigned ax) {
|
||||
return fpw_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::wpt(ir::value *value, unsigned ax) {
|
||||
return wpt_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
const std::map<int, ir::value*>& tiles::largest() {
|
||||
return largest_;
|
||||
}
|
||||
|
||||
|
||||
unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
return std::min(std::max(x, lo), hi);
|
||||
}
|
||||
|
||||
|
||||
void tiles::init_hmma_tile(ir::value *i) {
|
||||
auto order = coalesce_->get_order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned shape_0 = shapes[order[0]];
|
||||
unsigned shape_1 = shapes[order[1]];
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> fpw = {1, 1, 1};
|
||||
std::vector<unsigned> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
fpw_[axes_->get(i, d)] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
std::vector<unsigned> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
wpt_[axes_->get(i, d)] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= wpt_[axes_->get(i, d)];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void tiles::init_scanline_tile(ir::value *i) {
|
||||
auto order = coalesce_->get_order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned size = i->get_type()->get_tile_num_elements();
|
||||
unsigned ld = order[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = order[d];
|
||||
nts_[axes_->get(i, ld)] = 1;
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= mts_[axes_->get(i, d)];
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void tiles::run(ir::module &) {
|
||||
hmma_.clear();
|
||||
largest_.clear();
|
||||
size_t num_groups = layout_->get_num_groups();
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
||||
}
|
||||
// find out which value is the largest in each group
|
||||
// std::vector<unsigned> axes;
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
auto rank = [](ir::value* v) {
|
||||
ir::type *ty = v->get_type();
|
||||
size_t ret = 0;
|
||||
if(ty->is_tile_ty())
|
||||
for(int s: ty->get_tile_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
};
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||
}
|
||||
|
||||
// tiling parameters
|
||||
for(auto x: largest_){
|
||||
ir::value *i = x.second;
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
/* HMMA parameters*/
|
||||
if(hmma_[x.first])
|
||||
init_hmma_tile(i);
|
||||
else
|
||||
init_scanline_tile(i);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,6 +1,8 @@
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
@@ -584,8 +586,8 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
contiguous[i] = params_->nts(v, i);
|
||||
block_size[i] = params_->mts(v, i);
|
||||
contiguous[i] = tiles_->nts(v, i);
|
||||
block_size[i] = tiles_->mts(v, i);
|
||||
}
|
||||
to_warps(block_size, order, n_warps, warp_size);
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
|
||||
@@ -604,7 +606,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[params_->group_of(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
||||
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -622,13 +624,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
||||
Value *_16 = builder.getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = params_->fpw(v, 0);
|
||||
unsigned fpw_1 = params_->fpw(v, 1);
|
||||
unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1;
|
||||
unsigned fpw_0 = tiles_->fpw(v, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(v, 1);
|
||||
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = params_->wpt(v, 0);
|
||||
unsigned wpt_1 = params_->wpt(v, 1);
|
||||
unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1;
|
||||
unsigned wpt_0 = tiles_->wpt(v, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(v, 1);
|
||||
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
@@ -709,18 +711,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[params_->group_of(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[params_->group_of(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[params_->group_of(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
if(params_->fragment_of(v, 0) == analysis::grids::STRIDED_SCAN)
|
||||
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
|
||||
else
|
||||
if(tiles_->hmma(v))
|
||||
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
|
||||
else
|
||||
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
|
||||
}
|
||||
|
||||
/* -------------------
|
||||
@@ -780,7 +782,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = params_->group_of(v, d);
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -831,8 +833,8 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
for(ir::value* i: params_->get())
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
for(auto x: tiles_->largest())
|
||||
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -915,7 +917,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||
for(auto& x: partial) {
|
||||
// current element being computed
|
||||
Value *lane = axes_.at(params_->group_of(op, axis)).thread_id;
|
||||
Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id;
|
||||
Value *&result = x.second;
|
||||
indices_t write_idx = x.first;
|
||||
write_idx.insert(write_idx.begin() + axis, lane);
|
||||
@@ -928,7 +930,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = params_->wpt(op, axis);
|
||||
unsigned depth = tiles_->wpt(op, axis);
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
@@ -1095,12 +1097,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = params_->fpw(dot, 0);
|
||||
unsigned fpw_1 = params_->fpw(dot, 1);
|
||||
unsigned fpw_0 = tiles_->fpw(dot, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(dot, 1);
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = params_->wpt(dot, 0);
|
||||
unsigned wpt_1 = params_->wpt(dot, 1);
|
||||
unsigned wpt_0 = tiles_->wpt(dot, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(dot, 1);
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||
@@ -1241,10 +1243,11 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
||||
if(NK != 1) {
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(params_->fragment_of(dot, 0) == analysis::grids::STRIDED_SCAN)
|
||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
else
|
||||
if(tiles_->hmma(dot))
|
||||
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
||||
else
|
||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
|
||||
}
|
||||
else {
|
||||
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
|
||||
|
@@ -2,7 +2,6 @@
|
||||
#include <iostream>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -90,11 +89,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
if(new_value != old_value){
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(new_lhs, old_value);
|
||||
params_->copy(new_rhs, old_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,11 +121,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
|
||||
}
|
||||
if(new_value != old_value){
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
|
||||
}
|
||||
}
|
||||
|
||||
// extract constant and non-constant
|
||||
@@ -156,8 +145,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(analysis::align *align, analysis::grids* params)
|
||||
: params_(params), align_(align)
|
||||
reassociate::reassociate(analysis::align *align): align_(align)
|
||||
{ }
|
||||
|
||||
|
||||
@@ -183,9 +171,6 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value* static_range = ir::make_range_sta::get(old_range);
|
||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
||||
old_range->replace_all_uses_with(new_range);
|
||||
params_->copy(dyn_range, old_range);
|
||||
params_->copy(static_range, old_range);
|
||||
params_->copy(new_range, old_range);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,9 +199,6 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
params_->copy(ndyn, rt);
|
||||
params_->copy(nsta, rt);
|
||||
params_->copy(broadcast, rt);
|
||||
infos[rt] = cst_info{ndyn, nsta};
|
||||
}
|
||||
}
|
||||
@@ -236,8 +218,6 @@ void reassociate::run(ir::module &mod) {
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
params_->copy(dyn_ptr, pz);
|
||||
params_->copy(sta_ptr, pz);
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
@@ -252,8 +232,6 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *pz_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
|
||||
params_->copy(pz_dyn, pz);
|
||||
params_->copy(pz_sta, pz);
|
||||
pz->replace_all_uses_with(pz_sta);
|
||||
infos[pz_sta].dyn_ptr = pz_dyn;
|
||||
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
|
||||
@@ -298,12 +276,6 @@ void reassociate::run(ir::module &mod) {
|
||||
ir::value *neg_off = builder.create_neg(off);
|
||||
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
|
||||
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
|
||||
// copy parameters
|
||||
params_->copy(pz_dyn, pz);
|
||||
params_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
|
||||
params_->copy(neg_off, off);
|
||||
params_->copy(phi_dyn, phi);
|
||||
params_->copy(phi_sta, phi);
|
||||
infos[phi_sta].dyn_ptr = phi_dyn;
|
||||
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
|
||||
replaced.insert(phi);
|
||||
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -23,7 +23,6 @@ void vectorize::run(ir::module &mod) {
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
params_->copy(rx, x);
|
||||
}
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
@@ -33,7 +32,6 @@ void vectorize::run(ir::module &mod) {
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
params_->copy(rx, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -73,6 +73,10 @@ const type::tile_shapes_t &type::get_tile_shapes() const {
|
||||
return ((tile_type*)this)->get_shapes();
|
||||
}
|
||||
|
||||
const size_t type::get_tile_rank() const {
|
||||
return get_tile_shapes().size();
|
||||
}
|
||||
|
||||
unsigned type::get_tile_num_elements() const {
|
||||
const tile_shapes_t& shapes = get_tile_shapes();
|
||||
unsigned result = 1;
|
||||
|
@@ -3,6 +3,9 @@
|
||||
#include <regex>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
@@ -192,49 +195,54 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
// create passes
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
|
||||
codegen::analysis::grids grids(opt.num_warps, &coalesce);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&grids);
|
||||
codegen::transform::vectorize vectorize(&tiles);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &coalesce, target.get(), opt.num_warps);
|
||||
codegen::transform::reassociate reassociate(&alignment_info);
|
||||
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
alignment_info.run(module);
|
||||
if(target->is_gpu())
|
||||
shmem_info.run(module);
|
||||
shmem_info.run(module);
|
||||
coalesce.run(module);
|
||||
dce.run(module);
|
||||
grids.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
tiles.run(module);
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
if(target->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
shmem_barriers.run(module);
|
||||
dce.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
alignment_info.run(module);
|
||||
coalesce.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
tiles.run(module);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {2, 4, 8};
|
||||
opt.num_warps = {4};
|
||||
// create function
|
||||
rt::function function(src::dot, opt);
|
||||
// benchmark available libraries
|
||||
|
Reference in New Issue
Block a user