History prior to this date belonged to the now deprecated ISAAC project, and was deleted to save space

This commit is contained in:
Philippe Tillet
2021-07-27 12:38:38 -07:00
commit 6d7cf35123
202 changed files with 94034 additions and 0 deletions

View File

@@ -0,0 +1,514 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace analysis{
// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map[i] = value;
}
/*
* is constant
*/
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_tile_ty())
return ty->get_tile_shapes();
else
return {1};
}
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
ir::value* op = x->get_operand(0);
std::vector<cst_info> result;
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < x_shapes.size(); d ++){
cst_info ax ;
if(x_shapes[d] == 1)
ax = {1, op_cst[current].value};
else if(!is_skewed
&& x_shapes[d] == op_shapes[current])
ax = {x_shapes[d], op_cst[current++].value};
else {
is_skewed = true;
ax = {x_shapes[d], 0};
}
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
else
result.push_back(op_cst[d]);
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
auto x_shapes = get_shapes(x);
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
std::vector<cst_info> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
auto shapes = get_shapes(v);
std::vector<cst_info> result(shapes.size(), {1, 0});
return add_to_cache(v, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
}
/*
* max contiguous
*/
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
}
add_to_cache(x, result, max_contiguous_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto contiguous = populate_max_contiguous(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<unsigned> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({1});
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
}
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(1);
else
result.push_back(op_mc[d]);
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_cst_info[d].value > 0)
value = std::min(lhs_max_contiguous[d], rhs_cst_info[d].value);
if(x->is_int_mult()){
unsigned lvalue = 1, rvalue = 1;
if(rhs_cst_info[d].value == 1)
lvalue = lhs_max_contiguous[d];
if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst > 0)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst > 0)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_tile_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_max_contiguous_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_max_contiguous_gep(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_max_contiguous_phi(x);
return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
auto shapes = get_shapes(x);
auto op = populate_starting_multiple(x->get_operand(0));
std::vector<unsigned> result(shapes.size(), op[0]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
auto op = populate_starting_multiple(x->get_operand(0));
auto op_shapes = get_shapes(x->get_operand(0));
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result[d] = 1;
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result[d] = op[current++];
else {
is_skewed = true;
result[d] = 1;
}
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
if(x->is_int_mult())
result[d] = lhs[d] * rhs[d];
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
if(x->is_int_rem() && rhs[d] > 1)
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_shl())
result[d] = lhs[d] << rhs[d];
if(x->is_shr())
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++)
result[d] = gcd(lhs[d], rhs[d]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
auto shape = get_shapes(x);
std::vector<unsigned> result(shape.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
add_to_cache(x, result, starting_multiple_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto sm = populate_starting_multiple(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = gcd(result[d], sm[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of){
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
}
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8;
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}
}
return add_to_cache(v, {1}, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_dyn*>(v))
return add_to_cache(x, {128}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_starting_multiple_phi(x);
return populate_starting_multiple_default(v);
}
unsigned align::get(ir::value *v, unsigned ax) const {
unsigned starting_multiple = starting_multiple_.at(v)[ax];
unsigned max_contiguous = max_contiguous_.at(v)[ax];
return std::min(starting_multiple, max_contiguous);
}
std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
}
}
}
}

View File

@@ -0,0 +1,107 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void allocation::run(ir::module &mod) {
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<shared_layout*> I;
for(auto x: liveness_->get())
I.push_back(x.first);
std::vector<shared_layout*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<shared_layout*> V;
std::map<shared_layout*, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
res = res && !val.second.intersect(xj);
return res;
});
if(j_it != J.end()){
unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
if(xh.start < xj.start)
H.insert({w, segment{xh.start, xj.end}});
if(xj.end < xh.end)
H.insert({w, segment{xj.start, xh.end}});
V.push_back(*j_it);
J.erase(j_it);
}
}
// Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
for(shared_layout* y: V){
if(x == y)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->get_size();
unsigned NY = y->get_size();
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y))
&& XS.intersect(YS))
interferences[x].insert(y);
}
// Initialize colors
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(shared_layout* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(shared_layout* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
}
// Assigns first available color
auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It);
}
// Finalize allocation
for(shared_layout* x: V){
unsigned Adj = 0;
for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}
}
}
}

View File

@@ -0,0 +1,147 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
namespace triton{
namespace codegen{
namespace analysis{
axes::axes() {}
void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
unsigned current = 0;
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
continue;
graph_.add_edge({i, current++}, {arg, d});
}
}
void axes::update_graph_reshape(ir::instruction *i) {
auto* reshape = static_cast<ir::reshape_inst*>(i);
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_tile_shapes();
auto res_shapes = reshape->get_type()->get_tile_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
for(unsigned d = 0; d < res_shapes.size(); d ++){
bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape)
graph_.add_edge({i, d}, {op, current++});
else
graph_.add_edge({i, d}, {i, d});
// reshaping is skewed
if(res_shapes[d] > 1 && !same_shape)
is_skewed = true;
}
}
void axes::update_graph_trans(ir::instruction *i) {
auto *trans = static_cast<ir::trans_inst*>(i);
ir::value *op = trans->get_operand(0);
auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++)
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_tile_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
graph_.add_edge({i, d}, {op, d});
}
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_tile_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_tile_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_tile_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
graph_.add_edge({i, d}, {i, d});
}
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);;
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}
int axes::get(ir::value *value, unsigned dim) {
return axes_.at({value, dim});
}
std::vector<int> axes::get(ir::value *value) {
std::vector<int> result;
for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
result.push_back(this->get(value, d));
return result;
}
void axes::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
}
}
}
}

View File

@@ -0,0 +1,442 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_half_ty() &&
b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(v);
}
}
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
result = v;
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v;
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values_)
extract_io_use(v, ptr);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
});
if(*largest){
auto max_contiguous = align->contiguous(*largest);
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
}
}
size_t data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
return std::distance(axes_.begin(), it);
}
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
mma884_layout::mma884_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
fpw_ = {1, 1, 1};
std::vector<int> fpw_nm1;
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
do {
fpw_nm1 = fpw_;
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
if(fpw_[0]*fpw_[1] < num_fragments)
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
/* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
}while(wpt_nm1 != wpt_);
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
if(num_warps != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = num_warps * 32;
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
ptr = st->get_pointer_operand();
unsigned i = order_[0];
int contiguous = 4;
if(ptr)
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
if(num_warps * 32 != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
if(is_latch_0)
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
shared_layout::shared_layout(const data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size_ = 0;
// double-buffering
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
ir::value* hmma_dot_b = nullptr;
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
}
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order_ != arg_order) {
pad = 4;
}
shape_[order_[0]] += pad;
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
}
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
: axes_(axes), align_(align), num_warps_(num_warps) { }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_tile_ty())
return;
if(!y->get_type()->is_tile_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) {
return x->get_type()->get_tile_ranks1() <
y->get_type()->get_tile_ranks1();
};
std::vector<ir::value*> lvalue = values;
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v);
});
// type
if(it_hmma_c != values.end())
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
}
else
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_)
create(x.first, x.second);
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_tile_shapes();
unsigned shape_ax = shapes[axis];
scanline_layout *layout = get(arg)->to_scanline();
unsigned per_thread = layout->nts(axis);
unsigned depth = shape_ax / per_thread;
shapes[axis] = depth;
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
mma884_layout* in_layout = get(val)->to_mma884();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++;
ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
ir::type::tile_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
// create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;
}
});
}
}
}
}

View File

@@ -0,0 +1,57 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
}
}
// create live intervals
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
intervals_[layout] = segment{start, end};
}
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,326 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"
namespace triton{
namespace codegen{
using namespace llvm;
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_shared_layout::create(ir::value *v) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto double_buffer = layout_->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = offset_;
// base pointer
Value *ptr = ptr_;
if(double_buffer && v == double_buffer->latch)
ptr = next_ptr_;
else if(double_buffer && v == double_buffer->first)
ptr = pre_ptr_;
// create tile
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> order(rank);
// compute axes
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
// compute order
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
size_t posx = layout_->find_axis(axx);
size_t posy = layout_->find_axis(axy);
if(posx < rank && posy < rank)
return layout_->get_order(posx) < layout_->get_order(posy);
return false;
};
std::sort(order.begin(), order.end(), cmp);
return new distributed_tile(ty, shapes, order, axes, *builder_);
}
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shape = layout->get_shape();
if(shape.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shape.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// mma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// mma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::scanline_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->get_order();
const auto& shape = layout->get_shape();
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts * mts;
unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
}
}
}
}

View File

@@ -0,0 +1,214 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"
namespace triton{
namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order_[0]]++;
while(id[order_[k]] == axes_[order_[k]].values.size()){
if(k == id.size() - 1)
return;
id[order_[k++]] = 0;
id[order_[k]]++;
}
k = 0;
}
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
if(end < 0)
end = ordered_indices_.size() + end + 1;
for(unsigned i = start; i < end; i++)
fn(ordered_indices_[i]);
}
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
int rank = sizes.size();
int len = 1;
for(int s: sizes)
len *= s;
for(int i = 0; i < len; i++){
indices_t idx(rank);
int current = i;
for(int k = 0; k < rank; k++){
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
current = current / sizes[k];
}
fn(idx);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
const std::vector<int>& perm, const std::vector<int>& order,
indices_t idx) {
// strides
std::vector<Value*> strides(order.size());
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < strides.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
}
}

174
lib/codegen/target.cc Normal file
View File

@@ -0,0 +1,174 @@
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
using namespace llvm;
namespace triton{
namespace codegen{
// base
bool target::is_gpu() const {
return is_gpu_;
}
// AMD
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {});
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::r600_read_ngroups_x,
Intrinsic::r600_read_ngroups_y,
Intrinsic::r600_read_ngroups_z
};
Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_num_group, {});
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
// NVIDIA
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(fn),
MDString::get(ctx, "kernel"),
ValueAsMetadata::get(builder.getInt32(1))
};
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_tid_x,
Intrinsic::nvvm_read_ptx_sreg_tid_y,
Intrinsic::nvvm_read_ptx_sreg_tid_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_nctaid, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
// normal cpu functions can be kernels
}
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
static std::array<const Argument*, 3> ids = {
fn->arg_begin() + num_params - 3,
fn->arg_begin() + num_params - 2,
fn->arg_begin() + num_params - 1
};
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
}
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
return builder.getInt32(0);
}
}
}

View File

@@ -0,0 +1,143 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
ir::value *ptr = i->get_pointer_operand();
auto contiguous = align_->contiguous(ptr);
auto it = std::max_element(contiguous.begin(), contiguous.end());
int axis = std::distance(contiguous.begin(), it);
result[axis].push_back(i);
}
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
std::map<ir::value*, ir::value*>& seen) {
if(seen.find(x) != seen.end())
return seen.at(x);
auto i = dynamic_cast<ir::instruction*>(x);
// not an instruction -- forward value
if(!i)
return x;
// already in shared memory -- forward value
if(dynamic_cast<ir::copy_to_shared_inst*>(x)){
return x;
}
// set insert point
auto& inst_list = i->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), i);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(x)){
ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x));
return ret;
}
// default -- recursive clone
ir::instruction *cloned = builder.insert(i->clone());
seen[i] = cloned;
// rematerialize operands
for(ir::value *op: cloned->ops())
cloned->replace_uses_of_with(op, rematerialize(op, builder, seen));
return cloned;
}
void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) {
if(!layout_->get(id)->to_mma884())
continue;
// extract memory stores
const auto& values = layout_->values_of(id);
ir::value* dot = nullptr;
for(ir::value *v: values)
if(auto x = dynamic_cast<ir::dot_inst*>(v))
dot = x;
ir::builder& builder = mod.get_builder();
std::vector<ir::value*> worklist = {dot};
std::set<ir::value*> seen;
while(!worklist.empty()) {
ir::value *current = worklist.back();
seen.insert(current);
worklist.pop_back();
// stop if trunc
if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
builder.set_insert_point_after(x);
ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
builder.insert(rc);
x->replace_all_uses_with(rc);
rc->replace_uses_of_with(rc, x);
break;
}
// recurse
for(ir::user *u: current->get_users())
if(seen.find(u) == seen.end())
worklist.push_back(u);
}
}
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
for(ir::value *v: values)
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
extract_ld(i, axes);
}
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
remat.insert(remat.begin(), it->second.begin(), it->second.end());
}
// rematerialize values
for(ir::io_inst *r: remat) {
ir::builder& builder = mod.get_builder();
// rematerialize operands
std::map<ir::value*, ir::value*> seen;
for(ir::value *op: r->ops())
r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
// copy to shared if load
auto& inst_list = r->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(r)){
ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
r->replace_all_uses_with(cts);
cts->replace_uses_of_with(cts, r);
}
}
}
}
}
}

View File

@@ -0,0 +1,95 @@
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
return op==0;
return false;
}
inline bool is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
if(i->get_id() == ir::INST_TRANS)
return true;
if(i->get_id() == ir::INST_REDUCE)
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
return false;
}
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
// already in shared memory
if(to_shared && is_shmem_res(i))
return;
// copy
builder.set_insert_point_after(i);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k))
add_copy(i, i->get_operand(k), builder, true);
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
}
}
}
}
}

View File

@@ -0,0 +1,75 @@
#include "triton/codegen/transform/dce.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void dce::run(ir::module &mod) {
std::list<ir::instruction*> work_list;
std::set<ir::instruction*> marked;
// initialize work-list
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
switch(i->get_id()){
case ir::INST_RETURN:
case ir::INST_UNCOND_BRANCH:
case ir::INST_COND_BRANCH:
case ir::INST_UNMASKED_STORE:
case ir::INST_MASKED_STORE:
case ir::INST_ATOMIC_ADD:
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_EXCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
break;
}
default:
break;
}
}
}
// mark -- ignore branches
while(!work_list.empty()){
ir::instruction* current = work_list.back();
work_list.pop_back();
// mark instruction operands
for(ir::value* op: current->ops()) {
if(auto *i = dynamic_cast<ir::instruction*>(op)){
if(marked.insert(i).second)
work_list.push_back(i);
}
}
// TODO: mark last intstruction of current's reverse-dominance frontier
}
// sweep -- delete non-branch unmarked instructions
std::vector<ir::instruction*> to_delete;
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
if(marked.find(i) == marked.end())
to_delete.push_back(i);
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}
}
}

View File

@@ -0,0 +1,76 @@
#include "triton/codegen/transform/disassociate.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/builder.h"
#include "triton/ir/module.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
void extract_retile_chain(ir::user *root,
std::map<int, std::set<ir::user*>>& result,
int depth,
std::set<ir::value*>& seen) {
if(!seen.insert(root).second)
return;
result[depth].insert(root);
if(dynamic_cast<ir::make_range*>(root) ||
dynamic_cast<ir::splat_inst*>(root)){
return;
}
for(ir::value *op: root->ops()){
ir::user *u = dynamic_cast<ir::user*>(op);
if(!u)
continue;
extract_retile_chain(u, result, depth + 1, seen);
}
}
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen;
if(!dynamic_cast<ir::user*>(i->get_operand(0)))
return;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;
}
});
for(const auto& x: clone_info){
int depth = 1;
std::map<ir::instruction*, ir::instruction*> clone_map;
while(x.second.find(depth) != x.second.end()){
// clone all users
const auto& remat = x.second.at(depth);
for(ir::user* u: remat){
ir::instruction *y = (ir::instruction*)u;
ir::instruction *cloned = y->clone();
bld.set_insert_point(y);
bld.insert(cloned);
clone_map[y] = cloned;
// replace operands of parents
if(depth > 1)
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
else
x.first->replace_uses_of_with(y, cloned);
}
depth += 1;
}
}
}
}
}
}

View File

@@ -0,0 +1,168 @@
#include <vector>
#include <set>
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
bool membar::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
return left_intersect || right_intersect;
});
}
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void membar::add_reference(ir::value *v, interval_vec_t &res){
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(!i->get_type()->is_tile_ty())
return;
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if(!layout)
return;
if(alloc_->has_offset(layout)){
unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
}
}
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
}
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
add_reference(i, res);
}
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
assert(inc_val);
if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
builder.create_barrier();
}
}
}
else {
builder.set_insert_point(instr);
builder.create_barrier();
}
}
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
membar::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
}
std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc,
std::set<ir::value*>& safe_war) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from;
for(ir::instruction *i: instructions){
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
}
// record hazards
if(read_after_write || write_after_read) {
insert_loc.insert(i);
new_written_to.clear();
new_read_from.clear();
}
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
}
return std::make_pair(new_written_to, new_read_from);
}
void membar::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// extract phi-node associates with double-buffered
// shared-memory copies. These can be read from and written to
// without needing synchronization
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi)
safe_war.insert(v);
}
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from;
std::set<ir::instruction*> insert_locs;
size_t n_inserted_im1 = 0;
bool done = false;
do{
// find barrier location
for(ir::basic_block *block: rpo){
// written to
std::vector<interval_vec_t> pred_written_to;
for(ir::basic_block* pred: block->get_predecessors())
pred_written_to.push_back(written_to[pred]);
// read from
std::vector<interval_vec_t> pred_read_from;
for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]);
// apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
written_to[block] = result.first;
read_from[block] = result.second;
}
size_t n_inserted_i = insert_locs.size();
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(ir::instruction* i: insert_locs)
insert_barrier(i, builder);
}
}
}
}
}

View File

@@ -0,0 +1,191 @@
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
namespace triton {
namespace codegen{
namespace transform{
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
const std::vector<int>& perm) {
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n));
return result;
}
else if(auto i = dynamic_cast<ir::instruction*>(value)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
trans->set_operand(0, i);
return trans;
}
return nullptr;
}
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
auto trans = dynamic_cast<ir::trans_inst*>(value);
if(!trans)
return false;
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
return false;
ir::value* op = *ops.begin();
// trans(phi) -> phi(trans(), trans()...)
auto* phi = dynamic_cast<ir::phi_node*>(op);
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
if(!new_phi)
return false;
trans->replace_all_uses_with(new_phi);
return true;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, 0) + c -> dot(a, b, c)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) {
ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
if(!lhs_dot && !rhs_dot)
return false;
ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant_fp *_0 = nullptr;
if(splat)
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
if(!(_0 && _0->get_value() == 0.0))
return false;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
return false;
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_tile_shapes();
if(shapes[x->get_axis()] == 1){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
x->replace_all_uses_with(new_red);
return true;
}
return false;
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs))
_1_lhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs))
_1_rhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
if(!x)
return false;
auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
if(!y)
return false;
auto idx = *y->idx_begin();
auto z = dynamic_cast<ir::binary_operator*>(idx);
if(!z)
return false;
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
bool is_lhs_0 = lhs && (lhs->get_value()==0);
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
x->replace_all_uses_with(y->get_pointer_operand());
return true;
}
return false;
}
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
std::set<ir::value*> seen;
size_t n_seen;
// rewrite dots first
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = rewrite_dot(i, builder);
if(was_modified){
seen.insert(i);
}
}
}while(seen.size() != n_seen);
// rewrite other ops
seen.clear();
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
if(was_modified)
seen.insert(i);
}
}while(seen.size() != n_seen);
}
}
}
}

View File

@@ -0,0 +1,264 @@
#include <algorithm>
#include "triton/codegen/transform/reassociate.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
ir::binary_operator *bin_op = dynamic_cast<ir::binary_operator*>(x);
bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add;
if(is_bin_add)
return (ir::instruction*)x;
return nullptr;
}
inline bool is_cst(ir::value *x) {
if(dynamic_cast<ir::constant*>(x))
return true;
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
return is_cst(v->get_operand(0));
return false;
}
ir::value *reassociate::reassociate_idx(ir::value *old_value,
ir::builder &builder,
ir::value *&noncst,
ir::value *&cst){
// value doesn't change by default
ir::value* new_value = old_value;
cst = nullptr;
noncst = old_value;
// handle retiling
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
auto shapes = op->get_type()->get_tile_shapes();
ir::value *old_arg = op->get_operand(0);
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
// retile(x + y) = retile(x) + retile(y)
if(ir::instruction* bin_add = is_bin_add(new_arg))
if(cst){
ir::value *old_lhs = bin_add->get_operand(0);
ir::value *old_rhs = bin_add->get_operand(1);
ir::value *new_lhs = nullptr;
ir::value *new_rhs = nullptr;
if(dynamic_cast<ir::reshape_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
}
}
// handle binary addition
if(ir::instruction* op = is_bin_add(old_value)){
builder.set_insert_point(op);
std::string name = op->get_name();
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
builder.set_insert_point(op);
// (x + y) + z
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
ir::value *llhs = bin_lhs->get_operand(0);
ir::value *rlhs = bin_lhs->get_operand(1);
// (cst + x) + y -> cst + (x + y)
if(is_cst(llhs))
new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs), name);
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
ir::value *lrhs = bin_rhs->get_operand(0);
ir::value *rrhs = bin_rhs->get_operand(1);
// x + (cst + y) -> cst + (x + y)
if(is_cst(lrhs))
new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), name, cst);
// x + (y + cst) -> cst + (x + y)
if(is_cst(rrhs))
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
}
}
// extract constant and non-constant
if(ir::instruction *bin_add = is_bin_add(new_value)){
ir::value *new_lhs = bin_add->get_operand(0);
ir::value *new_rhs = bin_add->get_operand(1);
if(is_cst(new_lhs)){
cst = new_lhs;
noncst = new_rhs;
}
if(is_cst(new_rhs)){
cst = new_rhs;
noncst = new_lhs;
}
}
// clean-up if some re-ordering happened
if(old_value != new_value)
old_value->replace_all_uses_with(new_value);
return new_value;
}
/* run */
void reassociate::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::make_range*> ranges;
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list())
for(ir::value* op: i->ops())
if(auto *range = dynamic_cast<ir::make_range*>(op))
ranges.push_back(range);
}
builder.set_insert_point(rpo.front()->get_first_non_phi());
for(ir::make_range* old_range: ranges){
ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
ir::value* static_range = ir::make_range_sta::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
}
}
// reassociate
std::map<ir::value*, cst_info> infos;
std::set<ir::value*> replaced;
size_t n_replaced;
do{
n_replaced = replaced.size();
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list()){
// retiling
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
ir::value* op = rt->get_operand(0);
if(infos.find(op) != infos.end()){
builder.set_insert_point(rt);
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
ir::value* dyn = infos.at(op).dyn_ptr;
ir::value* cst = *sta->idx_begin();
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
auto shapes = rt->get_type()->get_tile_shapes();
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
ir::value* broadcast = builder.create_broadcast(cst, shapes);
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
infos[rt] = cst_info{ndyn, nsta};
}
}
}
// getelementptr instruction
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
if(replaced.find(pz) != replaced.end())
continue;
// unpack GEP instruction
ir::value* py = pz->get_pointer_operand();
ir::value* offset = *pz->idx_begin();
// reassociate index
ir::value *sta = nullptr;
ir::value *dyn = offset;
reassociate_idx(offset, builder, dyn, sta);
if(sta){
builder.set_insert_point(pz);
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
pz->replace_all_uses_with(sta_ptr);
infos[sta_ptr].dyn_ptr = dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
replaced.insert(pz);
}
// reassociate pointer argument
if(infos.find(py) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[py].sta_ptr;
ir::value *dyn = infos[py].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *pz_dyn = builder.create_gep(dyn, {off});
ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
pz->replace_all_uses_with(pz_sta);
infos[pz_sta].dyn_ptr = pz_dyn;
infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
replaced.insert(pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now
std::vector<ir::value*> ops = phi->ops();
if(ops.size() != 2)
continue;
if(ops[0] != pz && ops[1] != pz)
continue;
// grab incoming
size_t idx_z = (ops[0] == pz) ? 0 : 1;
size_t idx_a = (ops[0] == pz) ? 1 : 0;
// check if pa is known to have constant offset
ir::value *vpa = phi->get_incoming_value(idx_a);
auto it_a = infos.find(vpa);
if(it_a == infos.end())
continue;
// unpack dynamically/statically offset pointer
ir::value *pa_dyn = it_a->second.dyn_ptr;
ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
ir::value *pz = phi->get_incoming_value(idx_z);
// extract offset
ir::value *off = *pa_sta->idx_begin();
builder.set_insert_point(phi);
ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
builder.set_insert_point(phi->get_parent()->get_first_non_phi());
// re-add the offset
ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
phi->replace_all_uses_with(phi_sta);
// remove offset from pz
if(auto *x = dynamic_cast<ir::instruction*>(pz)){
auto insts = x->get_parent()->get_inst_list();
auto it = std::find(insts.begin(), insts.end(), x);
it++;
builder.set_insert_point(*it);
}
ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_tile_ty())
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
infos[phi_sta].dyn_ptr = phi_dyn;
infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
replaced.insert(phi);
}
}
}
}
}
}while(replaced.size() != n_replaced);
}
}
}
}