[codegen][analysis] cleaned-up tiling formalism

This commit is contained in:
Philippe Tillet
2019-09-15 21:14:14 -04:00
parent 031f4dfe96
commit 8d37a55a21
21 changed files with 710 additions and 561 deletions

View File

@@ -0,0 +1,49 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
void add_constraint(node_t x, node_t y);
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
public:
axes();
void run(ir::module &mod);
unsigned get(ir::value *value, unsigned ax);
bool has(ir::value *value, unsigned ax);
private:
// constraints graph
graph_t dependencies_;
std::set<node_t> nodes_;
// parameter groups
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
};
}
}
}
#endif

View File

@@ -1,89 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TUNE_H
#define TDL_INCLUDE_IR_CODEGEN_TUNE_H
#include <map>
#include <set>
#include <vector>
#include <memory>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
class function;
class metaparameter;
class constant_int;
}
namespace codegen{
namespace transform{
class coalesce;
}
namespace analysis{
class grids {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
typedef std::shared_ptr<int> param_ptr_t;
typedef std::map<ir::value*, std::map<int, param_ptr_t>> param_map_t;
public:
enum fragment_t{
STRIDED_SCAN,
HMMA_FRAGMENT_C
};
private:
void add_constraint(node_t x, node_t y);
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
void connected_components(node_t x, const std::vector<param_ptr_t>& params, const std::vector<param_map_t*>& maps, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, triton::ir::value *> &references,
ir::function *fn);
public:
grids(size_t num_warps, transform::coalesce* coalesce);
void run(ir::module &mod);
const std::vector<ir::value*> get() const { return grids_; }
fragment_t fragment_of(ir::value *value, unsigned ax);
unsigned group_of(ir::value *value, unsigned ax);
int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax);
int wpt(ir::value *value, unsigned ax);
void copy(ir::value *dst, ir::value *src);
private:
transform::coalesce* coalesce_;
// number of warps
size_t num_warps_;
// grids
std::vector<ir::value*> grids_;
// grid parameters
param_map_t fpw_;
param_map_t wpt_;
param_map_t mts_;
param_map_t nts_;
// constraints graph
graph_t dependencies_;
std::set<node_t> nodes_;
// fragments
std::map<node_t, fragment_t> fragments_;
// parameter groups
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
};
}
}
}
#endif

View File

@@ -0,0 +1,57 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes;
class layout {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
// list the axes of the given value
std::set<int> axes_of(ir::value *value);
public:
// constructor
layout(analysis::axes *axes);
// run the passes
void run(ir::module &mod);
// get the layout ID of the given value
unsigned id(ir::value *value) const;
// get the values associates with the given ID
const std::vector<ir::value*>& values(unsigned id) const;
// get number of groups
size_t get_num_groups() const;
private:
analysis::axes* axes_;
graph_t dependencies_;
std::set<node_t> nodes_;
std::map<ir::value*, unsigned> groups_;
std::map<unsigned, std::vector<ir::value*>> values_;
};
}
}
}
#endif

View File

@@ -15,15 +15,15 @@ namespace ir{
namespace codegen{
namespace analysis{
class grids;
class tiles;
class liveness;
class meminfo;
class memalloc {
public:
memalloc(liveness *live, meminfo *buffer_info, grids *params)
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
memalloc(liveness *live, meminfo *buffer_info, tiles *params)
: liveness_(live), buffer_info_(buffer_info), tiles_(params){ }
// utilities
unsigned num_bytes(ir::value *x);
unsigned is_ld_padded(ir::value* x);
@@ -40,7 +40,7 @@ private:
// dependences
liveness *liveness_;
meminfo *buffer_info_;
grids *params_;
tiles *tiles_;
};
}

View File

@@ -0,0 +1,68 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_TILES_H_
#define _TRITON_CODEGEN_ANALYSIS_TILES_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
class function;
class metaparameter;
class constant_int;
}
namespace codegen{
namespace transform{
class coalesce;
}
namespace analysis{
class axes;
class layout;
class tiles {
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
private:
void init_hmma_tile(ir::value *i);
void init_scanline_tile(ir::value *i);
public:
tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout);
void run(ir::module &mod);
bool hmma(ir::value *value);
int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax);
int wpt(ir::value *value, unsigned ax);
const std::map<int, ir::value*>& largest();
private:
// dependencies
analysis::layout* layout_;
analysis::axes* axes_;
transform::coalesce* coalesce_;
// number of warps
size_t num_warps_;
// tile properties
std::map<int, bool> hmma_;
std::map<int, ir::value*> largest_;
std::map<int, int> fpw_;
std::map<int, int> wpt_;
std::map<int, int> mts_;
std::map<int, int> nts_;
};
}
}
}
#endif

View File

@@ -43,10 +43,12 @@ namespace triton{
namespace codegen{
namespace analysis{
class grids;
class tiles;
class align;
class memalloc;
class meminfo;
class axes;
class layout;
}
namespace transform{
@@ -199,8 +201,12 @@ private:
public:
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt, unsigned num_warps)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
selection(analysis::memalloc *alloc, analysis::tiles *tiles, analysis::meminfo *buffer_info,
analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts,
transform::coalesce* reorder, target *tgt, unsigned num_warps)
: alloc_(alloc), tiles_(tiles), buffer_info_(buffer_info),
alignment_(alignment), a_axes_(axes), layouts_(layouts),
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst);
@@ -208,7 +214,9 @@ private:
vmap_t vmap_;
tmap_t tmap_;
analysis::memalloc *alloc_;
analysis::grids *params_;
analysis::tiles *tiles_;
analysis::axes *a_axes_;
analysis::layout *layouts_;
analysis::meminfo *buffer_info_;
analysis::align *alignment_;
transform::coalesce *reorder_;

View File

@@ -19,7 +19,7 @@ class getelementptr_inst;
namespace codegen{
namespace analysis{
class grids;
class tiles;
class align;
}
@@ -37,11 +37,10 @@ private:
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
reassociate(analysis::align* align, analysis::grids *params);
reassociate(analysis::align* align);
void run(ir::module& module);
private:
analysis::grids* params_;
analysis::align* align_;
};

View File

@@ -10,18 +10,18 @@ namespace ir {
namespace codegen{
namespace analysis{
class grids;
class tiles;
}
namespace transform{
class vectorize {
public:
vectorize(analysis::grids *params): params_(params){}
vectorize(analysis::tiles *params): params_(params){}
void run(ir::module &mod);
private:
analysis::grids *params_;
analysis::tiles *params_;
};
}

View File

@@ -63,6 +63,7 @@ public:
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
const tile_shapes_t& get_tile_shapes() const;
const size_t get_tile_rank() const;
unsigned get_tile_num_elements() const;
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;

View File

@@ -11,7 +11,7 @@
// codegen
#include "triton/codegen/selection.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
@@ -45,7 +45,7 @@ class translation_unit;
namespace codegen{
namespace analysis{
class grids;
class tiles;
}
}

View File

@@ -0,0 +1,166 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/context_impl.h"
#include "triton/ir/constant.h"
#include "triton/driver/device.h"
namespace triton{
namespace codegen{
namespace analysis{
axes::axes() {}
void axes::add_constraint(node_t x, node_t y) {
size_t shape_x = 1;
size_t shape_y = 1;
if(x.first->get_type()->is_tile_ty())
shape_x = x.first->get_type()->get_tile_shapes()[x.second];
if(y.first->get_type()->is_tile_ty())
shape_y = y.first->get_type()->get_tile_shapes()[y.second];
if(shape_x == 1 && shape_y == 1)
return;
dependencies_[x].insert(y);
dependencies_[y].insert(x);
nodes_.insert(x);
nodes_.insert(y);
}
void axes::init_c_graph(ir::instruction *v) {
// Reference shape
ir::type::tile_shapes_t shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
else if(dynamic_cast<ir::downcast_inst*>(v))
return;
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
return;
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
unsigned axis = reduce->get_axis();
ir::value *arg = reduce->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
unsigned current = 0;
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
add_constraint({reduce, current++}, {arg, i});
}
return;
}
else
shapes = v->get_type()->get_tile_shapes();
// Reshape
if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
unsigned current = 0;
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1){
add_constraint({v, i}, {v, i});
}
else if(!is_skewed &&
shapes[i] == op_shapes[current])
add_constraint({v, i}, {op, current++});
else{
is_skewed = true;
add_constraint({v, i}, {v, i});
}
}
}
// Splat
else if(dynamic_cast<ir::splat_inst*>(v)){
return;
}
// Trans
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0);
auto perm = x->get_perm();
for(unsigned i = 0; i < perm.size(); i++)
add_constraint({v, perm[i]->get_value()}, {op, i});
}
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){
ir::value *op = v->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i ++){
if(op_shapes[i] == shapes[i] && v != op)
add_constraint({v, i}, {op, i});
}
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
for(unsigned i = 0; i < shapes.size(); i ++){
std::vector<ir::value*> ops = v->ops();
for(ir::value* op: ops)
add_constraint({v, i}, {op, i});
}
}
}
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
unsigned axes::get(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
}
bool axes::has(ir::value *value, unsigned ax) {
auto it = groups_.find(value);
if(it == groups_.end())
return false;
auto iit = it->second.find(ax);
if(iit == it->second.end())
return false;
return true;
}
void axes::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
// Create graph
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_graph(i);
}
// Axes
unsigned group_id = 0;
while(!nodes_.empty())
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
}
}
}
}

View File

@@ -1,367 +0,0 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/context_impl.h"
#include "triton/ir/constant.h"
#include "triton/driver/device.h"
namespace triton{
namespace codegen{
namespace analysis{
grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), coalesce_(reorder)
{ }
bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4: TODO
}
return result;
}
void grids::add_constraint(node_t x, node_t y) {
dependencies_[x].insert(y);
dependencies_[y].insert(x);
nodes_.insert(x);
nodes_.insert(y);
}
void grids::init_c_phi(ir::instruction *v) {
// Phi Nodes: all the incoming value share the result layout
if(auto *phi = dynamic_cast<ir::phi_node*>(v))
for(ir::value *op: phi->ops())
for(unsigned k = 0; k < phi->get_type()->get_tile_shapes().size(); k++)
if(dependencies_.find({op, k}) != dependencies_.end()
|| dependencies_.find({phi, k}) != dependencies_.end()){
add_constraint({phi, k}, {op, k});
}
}
void grids::init_c_graph(ir::instruction *v) {
// Reference shape
ir::type::tile_shapes_t shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
else if(dynamic_cast<ir::downcast_inst*>(v))
return;
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
return;
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
unsigned axis = reduce->get_axis();
ir::value *arg = reduce->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
unsigned current = 0;
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
add_constraint({reduce, current++}, {arg, i});
}
return;
}
else
shapes = v->get_type()->get_tile_shapes();
// Reshape
if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
unsigned current = 0;
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1){
add_constraint({v, i}, {v, i});
}
else if(!is_skewed &&
shapes[i] == op_shapes[current])
add_constraint({v, i}, {op, current++});
else{
is_skewed = true;
add_constraint({v, i}, {v, i});
}
}
}
// Splat
else if(dynamic_cast<ir::splat_inst*>(v)){
return;
}
// Trans
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0);
auto perm = x->get_perm();
for(unsigned i = 0; i < perm.size(); i++)
add_constraint({v, perm[i]->get_value()}, {op, i});
}
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){
ir::value *op = v->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i ++){
if(op_shapes[i] == shapes[i] && v != op)
add_constraint({v, i}, {op, i});
}
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
for(unsigned i = 0; i < shapes.size(); i ++){
std::vector<ir::value*> ops = v->ops();
for(ir::value* op: ops)
add_constraint({v, i}, {op, i});
}
}
}
grids::fragment_t grids::get_fragmentation_type(node_t x, graph_t &graph){
std::list<node_t> work;
std::set<node_t> seen;
work.push_back(x);
while(!work.empty()){
node_t current = work.back();
if(is_hmma(current.first))
return HMMA_FRAGMENT_C;
work.pop_back();
seen.insert(current);
for(node_t y: graph[current]){
if(seen.find(y) == seen.end())
work.push_back(y);
}
}
return STRIDED_SCAN;
}
void grids::connected_components(node_t x, const std::vector<param_ptr_t>& ptr_vec, const std::vector<param_map_t*>& maps,
std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(unsigned i = 0; i < ptr_vec.size(); i++)
(*maps[i])[x.first][x.second] = ptr_vec[i];
for(const node_t &y: graph[x])
connected_components(y, ptr_vec, maps, nodes, graph, group_id);
}
}
unsigned grids::group_of(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
}
grids::fragment_t grids::fragment_of(ir::value *value, unsigned ax) {
return fragments_.at({value, ax});
}
//TODO: This shouldn't exist!
void grids::copy(ir::value *dst, ir::value *src) {
mts_[dst] = mts_[src];
nts_[dst] = nts_[src];
fpw_[dst] = fpw_[src];
wpt_[dst] = wpt_[src];
groups_[dst] = groups_[src];
fragments_[{dst, 0}] = fragments_[{src, 0}];
}
void grids::run(ir::module &mod) {
// Create tiling parameters
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_graph(i);
// Build phi constraints
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_phi(i);
// Layout parameters
unsigned group_id = 0;
for(auto x: nodes_)
fragments_[x] = get_fragmentation_type(x, dependencies_);
while(!nodes_.empty()) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
param_ptr_t nts(new int(-1));
param_ptr_t mts(new int(-1));
connected_components(node, {nts, mts}, {&nts_, &mts_}, nodes_, dependencies_, group_id++);
}
else {
param_ptr_t fpw(new int(-1));
param_ptr_t wpt(new int(-1));
connected_components(node, {fpw, wpt}, {&fpw_, &wpt_}, nodes_, dependencies_, group_id++);
}
}
}
for(ir::function *fn: mod.get_function_list()){
std::map<unsigned, ir::value*> references;
create_grids(grids_, references, fn);
}
unsigned num_threads = num_warps_*32;
auto clamp = [&](unsigned x, unsigned lo, unsigned hi) { return std::min(std::max(x, lo), hi); };
for(ir::value *i: grids_){
if(!i->get_type()->is_tile_ty())
continue;
auto order = coalesce_->get_order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned size = i->get_type()->get_tile_num_elements();
/* HMMA parameters*/
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> fpw = {1, 1, 1};
std::vector<unsigned> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
*fpw_[i][d] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> wpt = {1, 1, 1};
std::vector<unsigned> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
*wpt_[i][d] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= *wpt_[i][d];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* Scan-line */
else{
unsigned ld = order[0];
unsigned current = num_threads;
*nts_[i][ld] = clamp(size / num_threads, 1, 4);
*mts_[i][ld] = clamp(current, 1, shapes[ld] / *nts_[i][ld]);
current = current / *mts_[i][ld];
for(size_t d = 1; d < shapes.size(); d++){
ld = order[d];
*nts_[i][ld] = 1;
*mts_[i][ld] = clamp(current, 1, shapes[ld]);
current = current / *mts_[i][ld];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= *mts_[i][d];
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
}
}
void grids::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0;
for(auto shape: v->get_type()->get_tile_shapes()) {
result += (shape > 1)? shape : 0;
}
return result;
};
// bind references
std::set<ir::value*> seen;
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
{
// skip
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
// recurse
if(auto *user = dynamic_cast<ir::user*>(v))
for(ir::value *op: user->ops())
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] == 1)
continue;
unsigned x = group_of(v, d);
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
}
};
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
bind_references(i);
// create grid
for(auto &ref: references)
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
grids.push_back(ref.second);
}
int grids::mts(ir::value *value, unsigned ax) {
return *mts_.at(value).at(ax);
}
int grids::nts(ir::value *value, unsigned ax) {
return *nts_.at(value).at(ax);
}
int grids::fpw(ir::value *value, unsigned ax) {
return *fpw_.at(value).at(ax);
}
int grids::wpt(ir::value *value, unsigned ax) {
return *wpt_.at(value).at(ax);
}
}
}
}

View File

@@ -0,0 +1,96 @@
#include <algorithm>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
namespace triton{
namespace codegen{
namespace analysis{
// axes
std::set<int> layout::axes_of(ir::value *value) {
auto ty = value->get_type();
// rank of value
size_t rank = 0;
if(ty->is_tile_ty())
rank = ty->get_tile_rank();
// create result
std::set<int> result;
for(size_t d = 0; d < rank; d++){
if(axes_->has(value, d))
result.insert(axes_->get(value, d));
}
return result;
}
// connected components
void layout::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x] = group_id;
values_[group_id].push_back(x);
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
// constructor
layout::layout(analysis::axes *axes)
: axes_(axes) { }
// get group id
unsigned layout::id(ir::value *value) const
{ return groups_.at(value); }
// get values
const std::vector<ir::value*>& layout::values(unsigned id) const
{ return values_.at(id); }
// get number of groups
size_t layout::get_num_groups() const
{ return values_.size(); }
// run
void layout::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
values_.clear();
// Create graph
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()) {
// skip scalars
if(!i->get_type()->is_tile_ty())
continue;
// add an edge between i and the operands that share an axis
std::set<int> i_axes = axes_of(i);
nodes_.insert(i);
for(ir::value* op: i->ops()){
if(!op->get_type()->is_tile_ty())
continue;
nodes_.insert(op);
std::set<int> op_axes = axes_of(op);
std::set<int> common;
std::set_intersection(i_axes.begin(), i_axes.end(),
op_axes.begin(), op_axes.end(),
std::inserter(common, common.begin()));
if(!common.empty() || !op->get_type()->is_tile_ty()){
dependencies_[i].insert(op);
dependencies_[op].insert(i);
}
}
}
// Grids
unsigned group_id = 0;
while(!nodes_.empty()){
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
}
}
}
}
}

View File

@@ -2,7 +2,7 @@
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -20,7 +20,7 @@ unsigned memalloc::is_ld_padded(ir::value *x) {
}
for(ir::user* user: x->get_users())
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
bool is_hmma = params_->fragment_of(user, 0) == grids::HMMA_FRAGMENT_C;
bool is_hmma = tiles_->hmma(user);
bool is_op_0 = x == dot->get_operand(0);
bool is_op_1 = x == dot->get_operand(1);
if(is_hmma && is_op_0){
@@ -56,10 +56,10 @@ unsigned memalloc::num_bytes(ir::value *x) {
for(auto x: shapes)
num_elements *= x;
size_t depth;
if(params_->fragment_of(x, 0) == grids::HMMA_FRAGMENT_C)
depth = params_->wpt(op, axis);
if(tiles_->hmma(x))
depth = tiles_->wpt(op, axis);
else
depth = params_->mts(op, axis);
depth = tiles_->mts(op, axis);
return num_elements * num_bytes * depth;
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;

View File

@@ -0,0 +1,176 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/context_impl.h"
#include "triton/ir/constant.h"
#include "triton/driver/device.h"
namespace triton{
namespace codegen{
namespace analysis{
tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout):
num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout)
{ }
bool is_hmma(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_half_ty() &&
b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
bool tiles::hmma(ir::value *value) {
return hmma_.at(layout_->id(value));
}
int tiles::mts(ir::value *value, unsigned ax) {
return mts_.at(axes_->get(value, ax));
}
int tiles::nts(ir::value *value, unsigned ax) {
return nts_.at(axes_->get(value, ax));
}
int tiles::fpw(ir::value *value, unsigned ax) {
return fpw_.at(axes_->get(value, ax));
}
int tiles::wpt(ir::value *value, unsigned ax) {
return wpt_.at(axes_->get(value, ax));
}
const std::map<int, ir::value*>& tiles::largest() {
return largest_;
}
unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
void tiles::init_hmma_tile(ir::value *i) {
auto order = coalesce_->get_order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> fpw = {1, 1, 1};
std::vector<unsigned> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
do {
fpw_nm1 = fpw;
if(fpw[0]*fpw[1] < num_fragments)
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
if(fpw[0]*fpw[1] < num_fragments)
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
fpw_[axes_->get(i, d)] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> wpt = {1, 1, 1};
std::vector<unsigned> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
wpt_[axes_->get(i, d)] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt_[axes_->get(i, d)];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void tiles::init_scanline_tile(ir::value *i) {
auto order = coalesce_->get_order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned size = i->get_type()->get_tile_num_elements();
unsigned ld = order[0];
unsigned num_threads = num_warps_*32;
unsigned current = num_threads;
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
current = current / mts_[axes_->get(i, ld)];
for(size_t d = 1; d < shapes.size(); d++){
ld = order[d];
nts_[axes_->get(i, ld)] = 1;
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get(i, ld)];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts_[axes_->get(i, d)];
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
size_t num_groups = layout_->get_num_groups();
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
}
// find out which value is the largest in each group
// std::vector<unsigned> axes;
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
auto rank = [](ir::value* v) {
ir::type *ty = v->get_type();
size_t ret = 0;
if(ty->is_tile_ty())
for(int s: ty->get_tile_shapes())
ret += s > 1;
return ret;
};
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// tiling parameters
for(auto x: largest_){
ir::value *i = x.second;
if(!i->get_type()->is_tile_ty())
continue;
/* HMMA parameters*/
if(hmma_[x.first])
init_hmma_tile(i);
else
init_scanline_tile(i);
}
}
}
}
}

View File

@@ -1,6 +1,8 @@
#include "triton/codegen/selection.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
@@ -584,8 +586,8 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
contiguous[i] = params_->nts(v, i);
block_size[i] = params_->mts(v, i);
contiguous[i] = tiles_->nts(v, i);
block_size[i] = tiles_->mts(v, i);
}
to_warps(block_size, order, n_warps, warp_size);
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
@@ -604,7 +606,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->group_of(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
}
}
@@ -622,13 +624,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = params_->fpw(v, 0);
unsigned fpw_1 = params_->fpw(v, 1);
unsigned fpw_2 = is_batched ? params_->fpw(v, 2) : 1;
unsigned fpw_0 = tiles_->fpw(v, 0);
unsigned fpw_1 = tiles_->fpw(v, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1;
// warps per tile
unsigned wpt_0 = params_->wpt(v, 0);
unsigned wpt_1 = params_->wpt(v, 1);
unsigned wpt_2 = is_batched ? params_->wpt(v, 2) : 1;
unsigned wpt_0 = tiles_->wpt(v, 0);
unsigned wpt_1 = tiles_->wpt(v, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
@@ -709,18 +711,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
/* axes */
axes_[params_->group_of(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[params_->group_of(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[params_->group_of(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(params_->fragment_of(v, 0) == analysis::grids::STRIDED_SCAN)
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
else
if(tiles_->hmma(v))
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
else
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
}
/* -------------------
@@ -780,7 +782,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = params_->group_of(v, d);
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
@@ -831,8 +833,8 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
for(ir::value* i: params_->get())
init_axes(i, builder, u_thread_warp_id, u_warp_id);
for(auto x: tiles_->largest())
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
// create tile
std::set<ir::value*> seen;
for(ir::basic_block *block: fn->blocks())
@@ -915,7 +917,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
for(auto& x: partial) {
// current element being computed
Value *lane = axes_.at(params_->group_of(op, axis)).thread_id;
Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id;
Value *&result = x.second;
indices_t write_idx = x.first;
write_idx.insert(write_idx.begin() + axis, lane);
@@ -928,7 +930,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = params_->wpt(op, axis);
unsigned depth = tiles_->wpt(op, axis);
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));
@@ -1095,12 +1097,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
unsigned fpw_0 = params_->fpw(dot, 0);
unsigned fpw_1 = params_->fpw(dot, 1);
unsigned fpw_0 = tiles_->fpw(dot, 0);
unsigned fpw_1 = tiles_->fpw(dot, 1);
unsigned wts_0 = fpw_0 * 8;
unsigned wts_1 = fpw_1 * 8;
unsigned wpt_0 = params_->wpt(dot, 0);
unsigned wpt_1 = params_->wpt(dot, 1);
unsigned wpt_0 = tiles_->wpt(dot, 0);
unsigned wpt_1 = tiles_->wpt(dot, 1);
unsigned stride_rep_i = wpt_0 * wts_0;
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0] / stride_rep_i;
@@ -1241,10 +1243,11 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->fragment_of(dot, 0) == analysis::grids::STRIDED_SCAN)
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
else
if(tiles_->hmma(dot))
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
}
else {
distributed_tile *TA = (distributed_tile*)tmap_.at(A);

View File

@@ -2,7 +2,6 @@
#include <iostream>
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -90,11 +89,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(new_lhs, old_value);
params_->copy(new_rhs, old_value);
}
}
}
@@ -127,11 +121,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
if(is_cst(rrhs))
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
}
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
}
}
// extract constant and non-constant
@@ -156,8 +145,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
return new_value;
}
reassociate::reassociate(analysis::align *align, analysis::grids* params)
: params_(params), align_(align)
reassociate::reassociate(analysis::align *align): align_(align)
{ }
@@ -183,9 +171,6 @@ void reassociate::run(ir::module &mod) {
ir::value* static_range = ir::make_range_sta::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
params_->copy(dyn_range, old_range);
params_->copy(static_range, old_range);
params_->copy(new_range, old_range);
}
}
@@ -214,9 +199,6 @@ void reassociate::run(ir::module &mod) {
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
params_->copy(ndyn, rt);
params_->copy(nsta, rt);
params_->copy(broadcast, rt);
infos[rt] = cst_info{ndyn, nsta};
}
}
@@ -236,8 +218,6 @@ void reassociate::run(ir::module &mod) {
builder.set_insert_point(pz);
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
params_->copy(dyn_ptr, pz);
params_->copy(sta_ptr, pz);
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
@@ -252,8 +232,6 @@ void reassociate::run(ir::module &mod) {
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
params_->copy(pz_dyn, pz);
params_->copy(pz_sta, pz);
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
@@ -298,12 +276,6 @@ void reassociate::run(ir::module &mod) {
ir::value *neg_off = builder.create_neg(off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
// copy parameters
params_->copy(pz_dyn, pz);
params_->copy(((ir::instruction*)neg_off)->get_operand(0), off);
params_->copy(neg_off, off);
params_->copy(phi_dyn, phi);
params_->copy(phi_sta, phi);
infos[phi_sta].dyn_ptr = phi_dyn;
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
replaced.insert(phi);

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/transform/vectorize.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -23,7 +23,6 @@ void vectorize::run(ir::module &mod) {
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
params_->copy(rx, x);
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
ir::value *x = i->get_operand(0);
@@ -33,7 +32,6 @@ void vectorize::run(ir::module &mod) {
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
params_->copy(rx, x);
}
}
}

View File

@@ -73,6 +73,10 @@ const type::tile_shapes_t &type::get_tile_shapes() const {
return ((tile_type*)this)->get_shapes();
}
const size_t type::get_tile_rank() const {
return get_tile_shapes().size();
}
unsigned type::get_tile_num_elements() const {
const tile_shapes_t& shapes = get_tile_shapes();
unsigned result = 1;

View File

@@ -3,6 +3,9 @@
#include <regex>
#include <functional>
#include <algorithm>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/selection.h"
#include "triton/runtime/function.h"
#include "triton/codegen/transform/coalesce.h"
@@ -192,49 +195,54 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options_t& opt) {
std::unique_ptr<codegen::target> target = context->device()->make_target();
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes
codegen::analysis::meminfo shmem_info;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::align alignment_info;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
codegen::analysis::grids grids(opt.num_warps, &coalesce);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&grids);
codegen::transform::vectorize vectorize(&tiles);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&alignment_info, &grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, &coalesce, target.get(), opt.num_warps);
codegen::transform::reassociate reassociate(&alignment_info);
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
// run passes
peephole.run(module);
dce.run(module);
alignment_info.run(module);
if(target->is_gpu())
shmem_info.run(module);
shmem_info.run(module);
coalesce.run(module);
dce.run(module);
grids.run(module);
axes.run(module);
layouts.run(module);
tiles.run(module);
alignment_info.run(module);
reassociate.run(module);
dce.run(module);
peephole.run(module);
if(target->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
shmem_barriers.run(module);
}
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
shmem_barriers.run(module);
dce.run(module);
vectorize.run(module);
dce.run(module);
alignment_info.run(module);
coalesce.run(module);
dce.run(module);
// ir::print(module, std::cout);
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
axes.run(module);
layouts.run(module);
tiles.run(module);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {2, 4, 8};
opt.num_warps = {4};
// create function
rt::function function(src::dot, opt);
// benchmark available libraries