[PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22)

This commit is contained in:
Philippe Tillet
2022-07-26 11:06:45 -07:00
committed by GitHub
parent 96cc6fb563
commit 3265e0df5a
84 changed files with 1382 additions and 14023 deletions

View File

@@ -1,5 +0,0 @@
file(GLOB_RECURSE CODEGEN_SRC *.cc)
add_library(TritonCodeGen
${CODEGEN_SRC}
)

View File

@@ -1,533 +0,0 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace analysis{
// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y)
{
// Base Case
if (a == 0)
{
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
inline int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map[i] = value;
}
/*
* is constant
*/
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_block_ty())
return ty->get_block_shapes();
else
return {1};
}
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
ir::value* op = x->get_operand(0);
std::vector<cst_info> result;
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < x_shapes.size(); d ++){
cst_info ax ;
if(x_shapes[d] == 1)
ax = {1, op_cst[current].value};
else if(!is_skewed
&& x_shapes[d] == op_shapes[current])
ax = {x_shapes[d], op_cst[current++].value};
else {
is_skewed = true;
ax = {x_shapes[d], 0};
}
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
else
result.push_back(op_cst[d]);
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
auto x_shapes = get_shapes(x);
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
std::vector<cst_info> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
auto shapes = get_shapes(v);
std::vector<cst_info> result(shapes.size(), {1, 0});
return add_to_cache(v, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
}
/*
* max contiguous
*/
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
}
add_to_cache(x, result, max_contiguous_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto contiguous = populate_max_contiguous(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<unsigned> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({1});
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
}
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(1);
else
result.push_back(op_mc[d]);
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
auto lhs_starting_multiple = populate_starting_multiple(lhs);
auto rhs_starting_multiple = populate_starting_multiple(rhs);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_starting_multiple[d] > 0){
value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]);
}
if(x->is_int_mult()){
unsigned lvalue = 1, rvalue = 1;
if(rhs_cst_info[d].value == 1)
lvalue = lhs_max_contiguous[d];
if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_block_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_block_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
auto result = populate_max_contiguous(v->get_operand(0));
return add_to_cache(v, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(max_contiguous > 0)
return add_to_cache(x, {max_contiguous}, max_contiguous_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_max_contiguous_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_max_contiguous_gep(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_max_contiguous_phi(x);
return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
auto shapes = get_shapes(x);
auto op = populate_starting_multiple(x->get_operand(0));
std::vector<unsigned> result(shapes.size(), op[0]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
auto op = populate_starting_multiple(x->get_operand(0));
auto op_shapes = get_shapes(x->get_operand(0));
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result[d] = 1;
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result[d] = op[current++];
else {
is_skewed = true;
result[d] = 1;
}
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
if(x->is_int_mult())
result[d] = lhs[d] * rhs[d];
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = 1;
if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]);
}
if(x->is_shl())
result[d] = lhs[d] << rhs[d];
if(x->is_shr())
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
result[d] = gcd(lhs[d], rhs[d]);
// std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
auto shape = get_shapes(x);
std::vector<unsigned> result(shape.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
add_to_cache(x, result, starting_multiple_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto sm = populate_starting_multiple(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = gcd(result[d], sm[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_block_ty()) {
return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of){
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
}
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = std::max<int>(nbits / 8, 1);
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
}
}
}
return add_to_cache(v, {1}, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_starting_multiple_phi(x);
return populate_starting_multiple_default(v);
}
unsigned align::get(ir::value *v, unsigned ax) const {
unsigned starting_multiple = starting_multiple_.at(v)[ax];
unsigned max_contiguous = max_contiguous_.at(v)[ax];
return std::min(starting_multiple, max_contiguous);
}
std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}
}
}
}

View File

@@ -1,101 +0,0 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void allocation::run(ir::module &mod) {
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<shared_layout*> I;
for(auto x: liveness_->get())
I.push_back(x.first);
std::vector<shared_layout*> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<shared_layout*> V;
std::map<shared_layout*, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
segment xj = liveness_->get(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
res = res && !val.second.intersect(xj);
return res;
});
if(j_it != J.end()){
unsigned size = (*j_it)->get_size();
segment xj = liveness_->get(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
if(xh.start < xj.start)
H.insert({w, segment{xh.start, xj.end}});
if(xj.end < xh.end)
H.insert({w, segment{xj.start, xh.end}});
V.push_back(*j_it);
J.erase(j_it);
}
}
// Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
for(shared_layout* y: V){
if(x == y)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = x->get_size();
unsigned NY = y->get_size();
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get(x).intersect(liveness_->get(y))
&& XS.intersect(YS))
interferences[x].insert(y);
}
// Initialize colors
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(shared_layout* x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(shared_layout* Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
}
// Assigns first available color
auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It);
}
// Finalize allocation
for(shared_layout* x: V){
unsigned Adj = 0;
for(shared_layout* y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}
}
}
}

View File

@@ -1,162 +0,0 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
axes::axes() {}
void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
ir::value *arg = red->get_operand(0);
auto in_shapes = arg->get_type()->get_block_shapes();
unsigned current = 0;
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
continue;
graph_.add_edge({i, current++}, {arg, d});
}
}
void axes::update_graph_reshape(ir::instruction *i) {
auto* reshape = static_cast<ir::reshape_inst*>(i);
// operands
ir::value *op = reshape->get_operand(0);
// shapes
auto op_shapes = op->get_type()->get_block_shapes();
auto res_shapes = reshape->get_type()->get_block_shapes();
// construct edges
unsigned current = 0;
bool is_skewed = false;
for(unsigned d = 0; d < res_shapes.size(); d ++){
bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape)
graph_.add_edge({i, d}, {op, current++});
else
graph_.add_edge({i, d}, {i, d});
// reshaping is skewed
if(res_shapes[d] > 1 && !same_shape)
is_skewed = true;
}
}
void axes::update_graph_trans(ir::instruction *i) {
auto *trans = static_cast<ir::trans_inst*>(i);
ir::value *op = trans->get_operand(0);
auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++)
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes();
ir::value *op = broadcast->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_block_shapes();
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
graph_.add_edge({i, d}, {op, d});
}
void axes::update_graph_dot(ir::instruction *i) {
auto *dot = static_cast<ir::dot_inst*>(i);
auto shapes = dot->get_type()->get_block_shapes();
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_block_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++) {
// If we are dealing with a masked async load we need to attach the
// dimensions so we match the behaviour of the copy_to_shared instruction
// which async masked load replaces.
if (is_masked_load_async) {
graph_.add_edge({i, d}, {i, d});
}
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()) {
if(!is_masked_load_async && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
}
void axes::update_graph_no_edge(ir::instruction *i) {
if(!i->get_type()->is_block_ty())
return;
auto rank = i->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
graph_.add_edge({i, d}, {i, d});
}
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}
int axes::get(ir::value *value, unsigned dim) {
return axes_.at({value, dim});
}
std::vector<int> axes::get(ir::value *value) {
std::vector<int> result;
for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
result.push_back(this->get(value, d));
return result;
}
void axes::run(ir::module &mod) {
// make graph
graph_.clear();
axes_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
std::set<size_t> uniq;
for(auto x: axes_)
uniq.insert(x.second);
}
}
}
}

View File

@@ -1,653 +0,0 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
// #include "triton/ir/type.h"
namespace triton{
namespace codegen{
namespace analysis{
/* -------------------------------- *
* Helper Functions *
* -------------------------------- */
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
unsigned lo = std::min(a, b);
unsigned hi = std::max(a, b);
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v, int sm){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
}
return result;
}
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(v);
}
}
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && i->get_operand(n) == v)
result = v;
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i;
}
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
/* -------------------------------- *
* Layout Visitor *
* -------------------------------- */
void layout_visitor::visit_layout(data_layout *layout) {
layout->accept(this);
}
/* -------------------------------- *
* Base Data Layout *
* -------------------------------- */
data_layout::data_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
// io pointer
std::set<ir::value*> ptr;
for(ir::value* v: values_)
extract_io_use(v, ptr);
order_.resize(axes_.size());
std::iota(order_.begin(), order_.end(), 0);
std::vector<unsigned> max_contiguous;
for(ir::value* p: ptr){
std::vector<unsigned> curr = align->contiguous(p);
if(curr.size() > max_contiguous.size())
max_contiguous = curr;
else if(curr.size() == max_contiguous.size()){
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
max_contiguous = curr;
}
}
if(max_contiguous.size() > 0){
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
});
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// std::cout << order_[0] << " " << order_[1] << std::endl;
}
}
int data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find);
if(it == axes_.end())
return -1;
return std::distance(axes_.begin(), it);
}
distributed_layout::distributed_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(id, axes, shape, values, align)
{ }
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
mma_layout::mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
contig_per_thread_ = {1, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
// rep_ = {2, 2, 1};
}
order_ = {0, 1};
/* warps per tile */
wpt_ = {1, 1, 1};
// try to make warp-level tiles as square as possible to maximize data re-use
if (tgt->as_nvidia()->sm() < 80) {
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while (changed);
}
/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
/* -------------------------------- *
* Scanline Layout *
* -------------------------------- */
scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
nts_.resize(shape_.size());
mts_.resize(shape_.size());
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
std::vector<ir::value*> ptrs;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
ptrs.push_back(io->get_pointer_operand());
}
unsigned i = order_[0];
int contiguous = 1;
for(ir::value* ptr: ptrs){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i];
num_threads /= mts_[i];
if(is_dot)
nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
for(size_t d = 1; d < shape_.size(); d++){
i = order_[d];
if(d > 1 || !is_dot)
nts_[i] = 1;
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
shape_per_cta_.resize(shape_.size());
for(size_t d = 0; d < shape_.size(); d++)
shape_per_cta_[d] = mts_[d]*nts_[d];
}
/* -------------------------------- *
* Shared Layout *
* -------------------------------- */
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
return br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
return false;
else
throw std::runtime_error("unreachable");
}
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
if(!phi || phi->get_num_incoming() != 2)
return;
ir::basic_block *block_0 = phi->get_incoming_block(0);
ir::basic_block *block_1 = phi->get_incoming_block(1);
ir::instruction *terminator_0 = block_0->get_inst_list().back();
ir::instruction *terminator_1 = block_1->get_inst_list().back();
bool is_latch_0 = is_loop_latch(phi, terminator_0);
bool is_latch_1 = is_loop_latch(phi, terminator_1);
ir::value *value_0 = phi->get_incoming_value(0);
ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!(i_0 && !i_1) &&
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
return;
if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi});
if(is_latch_0)
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
static bool is_smem(ir::value* v) {
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v))
return true;
else
return false;
}
/// param:
/// value_1: next_value
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
std::vector<ir::value*>& values_0, ir::value*& value_1) {
ir::value* next = phi;
while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
// smem from previous bb & phi/smem from current bb
ir::value* c0 = cphi->get_incoming_value(0);
ir::value* c1 = cphi->get_incoming_value(1);
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) {
assert(cbb0 == bb0);
values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1;
continue;
} else {
if (is_smem(c1)) {
value_1 = c1;
assert(cbb1 == bb1);
return true;
} else {
return false;
}
}
} else
return false;
}
return false;
}
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
// if the phi node is nested
if (!phi)
return;
ir::basic_block *bb0 = phi->get_incoming_block(0);
ir::basic_block *bb1 = phi->get_incoming_block(1);
std::vector<ir::value*> values_0;
ir::value* value_1;
if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
return;
// double-buffer is a special case
if (values_0.size() == 1)
return;
// compute original values_0 input order
std::map<ir::value*, int> order;
int idx = 0;
for (ir::instruction* instr : *bb0) {
if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
order[static_cast<ir::value*>(instr)] = idx++;
}
assert(order.size() == values_0.size() && "order size incorrect");
int curr_stages = values_0.size() + 1;
if (curr_stages > prev_stages) {
res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
prev_stages = curr_stages;
}
}
shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
size_ = 0;
arg_layout_ = arg;
// N-stage buffering
int prev_stages = 0;
for (ir::value *v : values)
extract_N_bufferable(v, N_buffer_, prev_stages);
// double-buffering
if (!N_buffer_)
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer_);
// order
std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
order_ = arg_order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
ir::value* hmma_dot_b = nullptr;
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
}
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
size_ *= s;
if(double_buffer_)
size_ *= 2;
if (N_buffer_) {
size_ *= (N_buffer_->firsts.size() + 1);
}
}
int shared_layout::get_num_stages() const {
if (double_buffer_)
return 2;
if (N_buffer_)
return N_buffer_->firsts.size() + 1;
return 1;
}
size_t shared_layout::get_per_stage_elements() const {
return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
}
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
* -------------------------------- */
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
: axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }
void layouts::connect(ir::value *x, ir::value *y) {
if(x == y)
return;
if(!x->get_type()->is_block_ty())
return;
if(!y->get_type()->is_block_ty())
return;
std::vector<int> x_axes = axes_->get(x);
std::vector<int> y_axes = axes_->get(y);
std::set<int> sx_axes(x_axes.begin(), x_axes.end());
std::set<int> sy_axes(y_axes.begin(), y_axes.end());
std::set<int> common;
std::set_intersection(sx_axes.begin(), sx_axes.end(),
sy_axes.begin(), sy_axes.end(),
std::inserter(common, common.begin()));
graph_.add_edge(x, x);
graph_.add_edge(y, y);
if(!common.empty())
graph_.add_edge(x, y);
}
void layouts::make_graph(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
return xx < yy;
};
std::vector<ir::value*> lvalue = values;
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_block_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
});
// type
if(it_hmma_c != values.end()){
ir::instruction *dot = (ir::instruction*)*it_hmma_c;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
}
else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
}
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
layouts_.clear();
groups_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_)
create(x.first, x.second);
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
// shape
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id;
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
id++;
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k));
}
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
int in_vec = in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id;
}
});
}
}
}
}

View File

@@ -1,59 +0,0 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
}
}
// create live intervals
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
}
}
}
}
}

View File

@@ -1,61 +0,0 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
void swizzle::run(ir::module &) {
per_phase_.clear();
max_phase_.clear();
for(auto &x: layouts_->get_all()){
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
if(!layout)
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
continue;
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else {
if (!layout->allow_swizzle()) {
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}
}
}
}
}

View File

@@ -1,86 +0,0 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace triton {
namespace codegen {
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, int& shared_static) {
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
// run passes
dce.run(ir);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
disassociate.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
return llvm;
}
} // namespace codegen
} // namespace triton

File diff suppressed because it is too large Load Diff

View File

@@ -1,173 +0,0 @@
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
using namespace llvm;
namespace triton{
namespace codegen{
// base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
}
bool target::is_gpu() const {
return is_gpu_;
}
// AMD
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z
};
Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented on AMD");
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
Intrinsic::amdgcn_workitem_id_y,
Intrinsic::amdgcn_workitem_id_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
// NVIDIA
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(fn),
MDString::get(ctx, "kernel"),
ValueAsMetadata::get(builder.getInt32(1))
};
module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
return result;
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
return cta_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_tid_x,
Intrinsic::nvvm_read_ptx_sreg_tid_y,
Intrinsic::nvvm_read_ptx_sreg_tid_z
};
Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
return builder.CreateIntrinsic(ids[ax], {}, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
// normal cpu functions can be kernels
}
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
static std::array<const Argument*, 3> ids = {
fn->arg_begin() + num_params - 3,
fn->arg_begin() + num_params - 2,
fn->arg_begin() + num_params - 1
};
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
}
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
return builder.getInt32(0);
}
}
}

View File

@@ -1,133 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// simplify layout conversions using the following simple rules:
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
// ir::value* _op = inst->get_operand(0);
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
// std::cout << 1 << std::endl;
// // i must be layout conversion instruction
// if(!mma_in && !mma_out)
// return inst;
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
// if((mma_in || mma_out) && is_op_cvt &&
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
// return op->get_operand(0);
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
// return inst;
// std::cout << 1 << std::endl;
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = inst->clone();
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
// }
// std::cout << 2 << std::endl;
// return op;
//}
void coalesce::run(ir::module &mod) {
ir::builder& builder = mod.get_builder();
// add layout conversion instructions
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// coalesce before store
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
if(ir::value* op = i->get_operand(1))
if(op->get_type()->is_block_ty())
if(layout_->get(op)->to_mma()){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// uncoalesce after load
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);
x->replace_all_uses_with(new_x);
new_x->replace_uses_of_with(new_x, x);
}
}
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// re-arrange scanline to promote memory coalescing
if(auto x = dynamic_cast<ir::store_inst*>(i)){
ir::value* ptr = x->get_pointer_operand();
ir::value* val = x->get_value_operand();
auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst)
break;
if(dynamic_cast<ir::cvt_layout_inst*>(val))
break;
std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen;
std::vector<ir::io_inst*> ios;
while(!queue.empty()){
ir::instruction* curr = queue.back();
seen.insert(curr);
queue.pop_back();
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
break;
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
in_contig = align_->contiguous(io_inst->get_pointer_operand());
break;
}
for(ir::value* op: curr->ops()){
auto inst_op = dynamic_cast<ir::instruction*>(op);
if(!inst_op || seen.find(inst_op) != seen.end())
continue;
if(!op->get_type()->is_block_ty() ||
!val->get_type()->is_block_ty())
continue;
if(op->get_type()->get_tile_num_elements() ==
val->get_type()->get_tile_num_elements())
queue.push_back(inst_op);
}
}
if(in_contig.size() <= 1 || out_contig==in_contig)
continue;
builder.set_insert_point_after(val_inst);
auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
x->replace_uses_of_with(val_inst, new_val);
}
}
}
}
}
}

View File

@@ -1,97 +0,0 @@
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
return op==0;
return false;
}
inline bool is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
if(i->get_id() == ir::INST_TRANS)
return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true;
if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
return true;
return false;
}
// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
builder.set_insert_point(parent);
ir::value *copy;
if(to_shared)
copy = builder.create_copy_to_shared(x);
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
return;
}
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
return;
}
// already in shared memory
if(to_shared && is_shmem_res(i))
return;
// copy
builder.set_insert_point_after(i);
ir::value *copy;
if(to_shared){
copy = builder.create_copy_to_shared(x);
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
}
}
}
}
}

View File

@@ -1,75 +0,0 @@
#include "triton/codegen/transform/dce.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void dce::run(ir::module &mod) {
std::list<ir::instruction*> work_list;
std::set<ir::instruction*> marked;
// initialize work-list
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
switch(i->get_id()){
case ir::INST_RETURN:
case ir::INST_UNCOND_BRANCH:
case ir::INST_COND_BRANCH:
case ir::INST_UNMASKED_STORE:
case ir::INST_MASKED_STORE:
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
break;
}
default:
break;
}
}
}
// mark -- ignore branches
while(!work_list.empty()){
ir::instruction* current = work_list.back();
work_list.pop_back();
// mark instruction operands
for(ir::value* op: current->ops()) {
if(auto *i = dynamic_cast<ir::instruction*>(op)){
if(marked.insert(i).second)
work_list.push_back(i);
}
}
// TODO: mark last intstruction of current's reverse-dominance frontier
}
// sweep -- delete non-branch unmarked instructions
std::vector<ir::instruction*> to_delete;
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
if(marked.find(i) == marked.end())
to_delete.push_back(i);
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();
}
}
}
}

View File

@@ -1,62 +0,0 @@
#include "triton/codegen/transform/disassociate.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/builder.h"
#include "triton/ir/module.h"
#include <iostream>
namespace triton {
namespace codegen{
namespace transform{
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
std::set<ir::value*>& seen) {
if (dynamic_cast<ir::phi_node*>(root))
return root;
if(!seen.insert(root).second)
return root;
if(!root->get_type()->is_block_ty())
return root;
bld.set_insert_point(root);
ir::instruction *new_root = bld.insert(root->clone());
for(ir::value *op: root->ops()){
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
if(!i || i->get_id() == ir::INST_REDUCE)
continue;
ir::instruction* new_op = rematerialize(bld, i, seen);
new_root->replace_uses_of_with(op, new_op);
}
return new_root;
}
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
// ir::for_each_instruction(mod, [&](ir::instruction *i){
// bld.set_insert_point(i);
// for(ir::value* op: i->ops()){
// auto reshape = dynamic_cast<ir::make_range*>(op);
// if(!reshape)
// continue;
// ir::instruction* new_op = bld.insert(reshape->clone());
// i->replace_uses_of_with(op, new_op);
// }
// });
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
std::set<ir::value*> seen;
ir::instruction* new_i = rematerialize(bld, i, seen);
i->replace_all_uses_with(new_i);
}
});
}
}
}
}

View File

@@ -1,244 +0,0 @@
#include <vector>
#include <set>
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
return group_of(info->first, async_write);
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
if (v == info->phi)
return group_of(info->firsts[0], async_write);
else // prefetched value
return group_of(info->firsts[1], async_write);
}
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
}
else{
if(layouts_->has_tmp(v))
return async_write.size() - 1;
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
}
inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
if(!a_layout || !b_layout)
return false;
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
int b_start = alloc_->offset(b_layout);
int b_end = b_start + b_layout->get_size();
if(a_start < b_end || b_start < a_end)
return true;
return false;
}
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_block_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp))
ret.insert(b);
}
}
return ret;
}
bool membar::check_safe_war(ir::instruction* i) {
bool is_i_shared_block = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared();
bool is_i_double_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_double_buffer();
bool is_i_n_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_N_buffer();
if (is_i_double_buffered || is_i_n_buffered) {
// with async copy & prefetch_s disabled, WARs are not safe
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
return false;
else
return true;
}
return false;
}
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
bool& inserted, ir::builder& builder) {
std::vector<ir::async_wait_inst*> async_waits;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
if(dynamic_cast<ir::phi_node*>(i))
continue;
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
dynamic_cast<ir::masked_load_async_inst*>(i)){
async_write.push_back(i);
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
sync_write.insert(i);
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
if(layouts_->has_tmp(i))
read.insert(i);
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
if(intersect_with(read, tmp).size()){
std::vector<int> groups(read.size());
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
int N = *std::max_element(groups.begin(), groups.end());
if(N < async_write.size()){
builder.set_insert_point(i);
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
async_waits.push_back(async_wait);
}
}
// RAW, WAR
bool is_safe_war = check_safe_war(i);
// WAR barrier is not required when data is double-buffered
if(!intersect_with(read, sync_write).empty() ||
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
// update state of asynchronous copies
if(async_wait){
int N = async_write.size() - async_wait->get_N();
async_write.erase(async_write.begin(), async_write.begin() + N);
}
// all the copy_to_shared and read from shared are synchronized after barrier
if(barrier){
sync_write.clear();
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
// coalesce barriers
// fixme: to support more general cases
if (async_waits.size() == 2) {
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
for (int idx=0; idx<async_waits.size()-1; ++idx) {
ir::async_wait_inst *first_async_wait = async_waits[idx];
std::vector<ir::instruction*> to_erase;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
ir::instruction *i = *iter;
if (static_cast<ir::instruction*>(first_async_wait) == i) {
// peak next 5 instructions
auto peak_iter = std::next(iter);
if (std::distance(peak_iter, instructions.end()) >= 5) {
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
int first_n = first_async_wait->get_N();
int second_n = second_async_wait->get_N();
to_erase.push_back(second_async_wait);
to_erase.push_back(second_bar);
first_async_wait->set_N(second_n);
}
} else
break;
for (ir::instruction *i : to_erase)
block->erase(i);
}
}
}
}
}
void membar::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// extract phi-node associates with double-buffered
// shared-memory copies. These can be read from and written to
// without needing synchronization
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi){
safe_war.insert(v);
}
}
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
std::map<ir::basic_block*, val_set_t> sync_reads;
std::list<ir::value *> pipelined;
bool inserted;
do{
inserted = false;
// find barrier location
for(ir::basic_block *block: rpo){
// join inputs
val_vec_t async_write;
val_set_t sync_write;
val_set_t sync_read;
val_set_t tmp;
for(ir::basic_block* pred: block->get_predecessors()){
for(ir::value* v: async_writes[pred])
if(tmp.insert(v).second)
async_write.push_back(v);
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
}
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
async_writes[block] = async_write;
sync_writes[block] = sync_write;
sync_reads[block] = sync_read;
}
}while(inserted);
}
}
}
}
}

View File

@@ -1,309 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
namespace transform{
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
const std::vector<int>& perm) {
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n));
return result;
}
else if(auto i = dynamic_cast<ir::instruction*>(value)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
trans->set_operand(0, i);
return trans;
}
return nullptr;
}
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
auto trans = dynamic_cast<ir::trans_inst*>(value);
if(!trans)
return false;
auto users = trans->get_users();
auto ops = trans->ops();
if(users.size() > 1 || ops.size() > 1)
return false;
ir::value* op = *ops.begin();
// trans(phi) -> phi(trans(), trans()...)
auto* phi = dynamic_cast<ir::phi_node*>(op);
if(!phi)
return false;
ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
if(!new_phi)
return false;
trans->replace_all_uses_with(new_phi);
return true;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
if(!lhs_dot && !rhs_dot)
return false;
ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant *_0 = nullptr;
if(splat)
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
if(!_0)
return false;
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
if (fp_0->get_value() != 0.0)
return false;
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
if (int_0->get_value() != 0)
return false;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}
return false;
}
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
// if(cfs) {
// ir::value *arg = cfs->get_operand(0);
// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
// if(!cts)
// return false;
// cfs->replace_all_uses_with(cts->get_operand(0));
// return true;
// }
//}
bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
if(!copy_to_shared)
return false;
ir::value *arg = copy_to_shared->get_operand(0);
ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
if(!ld)
return false;
builder.set_insert_point(copy_to_shared);
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
return false;
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
// return true;
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x)
return false;
ir::value *arg = x->get_operand(0);
auto shapes = arg->get_type()->get_block_shapes();
if(shapes[x->get_axis()] == 1){
builder.set_insert_point(x);
ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
x->replace_all_uses_with(new_red);
return true;
}
return false;
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
auto x = dynamic_cast<ir::getelementptr_inst*>(value);
if(!x)
return false;
auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
if(!y)
return false;
auto idx = *y->idx_begin();
auto z = dynamic_cast<ir::binary_operator*>(idx);
if(!z)
return false;
bool is_sub = z->get_op() == ir::binary_op_t::Sub;
auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
bool is_lhs_0 = lhs && (lhs->get_value()==0);
bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
x->replace_all_uses_with(y->get_pointer_operand());
return true;
}
return false;
}
bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
auto select = dynamic_cast<ir::select_inst*>(value);
if(!select)
return false;
auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
if(!if_value)
return false;
if(select->get_pred_op() != if_value->get_mask_operand())
return false;
builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(),
select->get_else_value_op(),
if_value->get_cache_modifier(),
if_value->get_eviction_policy(),
if_value->get_is_volatile());
select->replace_all_uses_with(new_load);
return true;
}
bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
if(!cvt)
return false;
ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
if(!op)
return false;
// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
// if(op->get_id() == ir::INST_BINOP){
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = cvt->clone();
// layouts_->copy(new_arg_i, op);
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, new_arg_i);
// }
// cvt->replace_all_uses_with(op);
// return true;
// }
auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
if(!cvt_op)
return false;
// convert1(convert2(x)) if convert1 is the inverse of convert2
ir::value* op_op = cvt_op->get_operand(0);
if(layouts_->has(cvt) && layouts_->has(op_op) &&
layouts_->get(cvt) && layouts_->get(op_op)){
cvt->replace_all_uses_with(op_op);
return true;
}
return false;
}
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
std::set<ir::value*> seen;
size_t n_seen;
// rewrite dots first
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = rewrite_dot(i, builder);
if(was_modified){
seen.insert(i);
}
}
}while(seen.size() != n_seen);
// rewrite other ops
seen.clear();
do{
n_seen = seen.size();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(seen.find(i) != seen.end())
continue;
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
// was_modified = was_modified || rewrite_select_masked_load(i, builder);
was_modified = was_modified || rewrite_cvt_layout(i, builder);
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}
}while(seen.size() != n_seen);
}
}
}
}

View File

@@ -1,330 +0,0 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
auto instr = dynamic_cast<ir::instruction*>(cond);
for (auto op : instr->ops()) {
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
phis.insert(phi_op);
return;
}
if (dynamic_cast<ir::instruction*>(op))
get_induction_vars(op, phis);
}
}
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
/// moving the prev phi vals to the next iteration
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
std::map<ir::phi_node*, ir::value*> next_phi_vals;
for (auto &[phi, val] : prev_phi_vals) {
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
}
return next_phi_vals;
}
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
next_load_ivs[phi] = next_k;
} else
throw std::runtime_error("must be phi");
}
}
struct pipeline_info_t {
ir::load_inst* load;
ir::phi_node* ptr;
ir::dot_inst* dot;
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
: load(load), ptr(ptr), dot(dot) {}
};
void pipeline::run(ir::module &mod) {
if (num_stages_ <= 1)
return;
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<pipeline_info_t> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dot)
to_pipeline.push_back({load, ptr, dot});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// multi-stage pipe
if (has_copy_async_ && num_stages > 2) {
ir::value* header_cond = header_br->get_cond();
ir::value* block_cond = block_br->get_cond();
// 1. collect induction variables
std::set<ir::phi_node*> induction_vars;
get_induction_vars(block_cond, induction_vars);
std::vector<ir::value*> first_ptrs(num_stages-1);
std::vector<ir::value*> first_loads(num_stages-1);
std::vector<ir::value*> first_masks(num_stages-1);
std::vector<ir::value*> loop_conds(num_stages-1);
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// Add all phi nodes. The following DCE pass will delete dead ones.
for (ir::instruction *instr : block->get_inst_list())
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
if (phi->get_incoming_block(1) == block)
prev_phi_vals[phi] = phi->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
loop_conds[0] = header_cond;
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
}
// create new phis for induction variables
builder.set_insert_point(block->get_first_non_phi());
std::map<ir::phi_node*, ir::value*> load_ivs;
std::map<ir::phi_node*, ir::value*> next_load_ivs;
for (auto& [iv, val] : prev_phi_vals) {
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
pn->add_incoming(prev_phi_vals[iv], header);
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
// ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
// TODO: false may depends on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
ptr->set_incoming_value(0, first_ptrs.back());
builder.set_insert_point(block->get_first_non_phi());
// nested phis for load
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
for (auto& pn : new_load_phis)
pn = builder.create_phi(ty, 2);
for (int i=0; i<num_stages-2; ++i) {
new_load_phis[i]->add_incoming(first_loads[i], header);
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
}
new_load_phis.back()->add_incoming(first_loads.back(), header);
new_load_phis.back()->add_incoming(next_load, block);
load->replace_all_uses_with(new_load_phis.front());
new_loads.push_back(new_load_phis.back());
// record first_loads to reorder them
preheader_loads.push_back({new_load_phis.front(), first_loads});
} else {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
}
// try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to
// a0, b0, a1, b1, ...
if (!preheader_loads.empty()) {
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
builder.set_insert_point(header->get_inst_list().back());
for (int i=1; i<num_stages-1; ++i) {
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
ir::instruction* moved_load = original_load->clone();
builder.insert(moved_load);
original_load->replace_all_uses_with(moved_load);
}
}
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::vector<move_config_t> to_move(to_pipeline.size());
if(has_copy_async_){
for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
auto info = to_pipeline[idx];
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::dot_inst* dot = info.dot;
ir::basic_block* bb = dot->get_parent();
recursive_deps(dot, bb, to_move[idx].insts);
to_move[idx].dst = load;
}
for(auto& move_config: to_move){
builder.set_insert_point_after(move_config.dst);
for(ir::instruction* i: move_config.insts){
i->get_parent()->erase(i);
builder.insert(i);
}
}
}
}
}
}
}

View File

@@ -1,133 +0,0 @@
#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/print.h"
#include <iostream>
#include <vector>
#include <algorithm>
namespace triton::codegen::transform {
/// find defs till phis
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
ir::instruction *i = dynamic_cast<ir::instruction*>(v);
if (!i || i->get_parent() != bb)
return;
if (i->get_id() == ir::INST_PHI)
return;
ret.push_back(i);
for (ir::value *op : i->ops())
recursive_defs(op, bb, ret);
}
void prefetch::run(ir::module &mod) {
// 1. collect dots that can be prefethced
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is using tensor cores
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
if (a && a->get_incoming_block(1) == a->get_parent() &&
b && b->get_incoming_block(1) == b->get_parent())
to_prefetch.push_back(dot);
}
});
assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots");
ir::builder &builder = mod.get_builder();
// 2. do the prefetching
for (ir::dot_inst* dot : to_prefetch) {
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
assert(a->get_incoming_block(0) == b->get_incoming_block(0));
ir::basic_block *loop_header = a->get_incoming_block(0);
ir::basic_block *loop_body = a->get_parent();
// mark as prefetched
dot->set_prefetched(true);
// 1. in the loop header (first iteration)
builder.set_insert_point(loop_header->get_inst_list().back());
assert(a && b);
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
// 2. at the end of the loop body (next iteration)
builder.set_insert_point(loop_body->get_inst_list().back());
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
prefetched_vals_.insert(a->get_incoming_value(0));
prefetched_vals_.insert(b->get_incoming_value(0));
// nested phis
ir::value* next_a = a->get_incoming_value(1);
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
next_a = next_a_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_a);
ir::value* next_b = b->get_incoming_value(1);
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
next_b = next_b_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_b);
}
// move loads to the beginning of the loop
if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
for (ir::function *fn : mod.get_function_list())
for (ir::basic_block *bb : fn->blocks()) {
// only apply to loop body
if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
continue;
// record loads (& dependency) to move
std::vector<ir::instruction*> loads;
// record original inst order
std::map<ir::instruction*, size_t> idx_map;
size_t idx = 0;
for (ir::instruction *inst : bb->get_inst_list()) {
if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
recursive_defs(i, bb, loads);
idx_map[inst] = idx;
idx++;
}
// remove duplicates & keep the original input order
std::sort(loads.begin(), loads.end());
loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
return idx_map[a] < idx_map[b];
});
builder.set_insert_point(bb->get_first_non_phi());
auto& inst_list = bb->get_inst_list();
for (ir::instruction *i : loads){
auto it = std::find(inst_list.begin(), inst_list.end(), i);
// make sure we don't invalidate insert point
// in case instruction already at the top
if(it == builder.get_insert_point())
continue;
bb->erase(i);
builder.insert(i);
}
}
}
}
} // namespace triton::codegen::transform

View File

@@ -1,51 +0,0 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"
namespace triton {
namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
// ir::builder &builder = mod.get_builder();
// std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction* i: block->get_inst_list()){
// if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
// ir::value* _ptr = ld->get_pointer_operand();
// ir::value* _msk = ld->get_mask_operand();
// ir::value* _val = ld->get_false_value_operand();
// auto ptr = std::find(block->begin(), block->end(), _ptr);
// auto msk = std::find(block->begin(), block->end(), _msk);
// auto val = std::find(block->begin(), block->end(), _val);
// if(ptr == block->end() || msk == block->end() || val == block->end())
// continue;
// auto it = std::find(block->begin(), block->end(), i);
// int dist_ptr = std::distance(ptr, it);
// int dist_msk = std::distance(msk, it);
// int dist_val = std::distance(val, it);
// if(dist_ptr < dist_msk && dist_ptr < dist_val)
// builder.set_insert_point(++ptr);
// if(dist_msk < dist_ptr && dist_msk < dist_val)
// builder.set_insert_point(++msk);
// if(dist_val < dist_ptr && dist_val < dist_msk)
// builder.set_insert_point(++val);
// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
// to_replace.push_back(std::make_pair(ld, new_ld));
// }
// }
// for(auto& x: to_replace)
// x.first->replace_all_uses_with(x.second);
}
}
}
}