History prior to this date belonged to the now-deprecated ISAAC project and was deleted to save space
This commit is contained in:
514
lib/codegen/analysis/align.cc
Normal file
514
lib/codegen/analysis/align.cc
Normal file
@@ -0,0 +1,514 @@
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and writes coefficients to *x and *y such that
// a * (*x) + b * (*y) == g.
int gcd_impl(int a, int b, int *x, int *y)
{
  // gcd(0, b) = b = 0 * a + 1 * b
  if(a == 0){
    *x = 0;
    *y = 1;
    return b;
  }

  // solve the smaller problem (b mod a, a) first
  int sub_x = 0;
  int sub_y = 0;
  const int quotient = b / a;
  const int remainder = b % a;
  const int divisor = gcd_impl(remainder, a, &sub_x, &sub_y);

  // lift the coefficients of the sub-problem back to (a, b)
  *x = sub_y - quotient * sub_x;
  *y = sub_x;
  return divisor;
}
|
||||
|
||||
// Greatest common divisor of a and b.
// Runs the same (b mod a, a) reduction as gcd_impl, without tracking the
// Bezout coefficients that callers of this overload never use.
int gcd(int a, int b) {
  while(a != 0){
    const int remainder = b % a;
    b = a;
    a = remainder;
  }
  return b;
}
|
||||
|
||||
|
||||
inline int lcm(int a, int b) {
|
||||
return (a * b) / gcd(a, b);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
return map[i] = value;
|
||||
}
|
||||
|
||||
/*
|
||||
* is constant
|
||||
*/
|
||||
|
||||
// Shape used by the analyses for `v`: the tile shape for tile-typed values,
// and the single-entry shape {1} for everything else.
std::vector<unsigned> align::get_shapes(ir::value *v) {
  ir::type *ty = v->get_type();
  if(!ty->is_tile_ty())
    return {1};
  return ty->get_tile_shapes();
}
|
||||
|
||||
// Constancy info for a phi node.
// Every axis starts at {num_cst = 1, value = 0}; if one or more incoming
// values have already been analysed, the cached info of the last such
// incoming value is adopted as-is. The result is cached and returned.
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
  auto shapes = get_shapes(x);
  std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    auto it = is_constant_.find(x->get_incoming_value(n));
    if(it != is_constant_.end())
      result = it->second;
  }
  // NOTE(review): the original returned here unconditionally, which left a
  // trailing "recurse" loop (per-axis minimum of num_cst over all incoming
  // values) unreachable. That dead code is removed; runtime behaviour is
  // unchanged. populate_max_contiguous_phi and populate_starting_multiple_phi
  // do merge every incoming value after seeding the cache -- confirm whether
  // skipping the merge here is intentional.
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy info for a splat: along every axis the number of constant
// elements is the axis extent, and the value is the one recorded for the
// first entry of the splatted operand.
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
  const std::vector<unsigned> dims = get_shapes(x);
  const std::vector<cst_info> operand_info = populate_is_constant(x->get_operand(0));
  std::vector<cst_info> result;
  result.reserve(dims.size());
  for(size_t k = 0; k < dims.size(); k++)
    result.push_back(cst_info{dims[k], operand_info[0].value});
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy info for a reshape.
// Result axes are matched left-to-right against operand axes: an axis of
// extent 1 borrows the value of the current operand axis without consuming
// it, an axis whose extent equals the current operand axis consumes it, and
// the first mismatch marks the reshape as skewed, after which every remaining
// non-unit axis gets value 0.
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto op_cst = populate_is_constant(op);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < x_shapes.size(); d++){
    // Once every operand axis has been consumed (e.g. [N] -> [N, 1]) there is
    // no operand entry left to read; indexing it would be out of bounds.
    const bool in_range = current < op_shapes.size() && current < op_cst.size();
    cst_info ax;
    if(x_shapes[d] == 1)
      // fall back to 0, the value populate_is_constant_default records
      ax = {1, in_range ? op_cst[current].value : 0u};
    else if(!is_skewed && in_range
            && x_shapes[d] == op_shapes[current])
      ax = {x_shapes[d], op_cst[current++].value};
    else {
      is_skewed = true;
      ax = {x_shapes[d], 0};
    }
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy info for a broadcast: axes where the operand has extent 1 become
// constant over the whole result extent (keeping the operand's value); all
// other axes inherit the operand's info unchanged.
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  const std::vector<unsigned> dims = get_shapes(x);
  ir::value *src = x->get_operand(0);
  auto src_dims = src->get_type()->get_tile_shapes();
  const std::vector<cst_info> src_info = populate_is_constant(src);
  std::vector<cst_info> result;
  result.reserve(dims.size());
  for(size_t d = 0; d < dims.size(); d++){
    const bool broadcast_axis = src_dims[d] == 1;
    result.push_back(broadcast_axis ? cst_info{dims[d], src_info[d].value}
                                    : src_info[d]);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy info for a binary operator.
// Default: per-axis minimum of the operands' num_cst. Special case: for an
// integer division whose left operand has num_cst == 0 and whose right
// operand has a non-zero recorded value, num_cst is the gcd of the left
// operand's max contiguity and that value. The recorded value is always 0.
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
  const std::vector<unsigned> dims = get_shapes(x);
  ir::value *a = x->get_operand(0);
  ir::value *b = x->get_operand(1);
  const std::vector<cst_info> a_cst = populate_is_constant(a);
  const std::vector<cst_info> b_cst = populate_is_constant(b);
  const std::vector<unsigned> a_contiguous = populate_max_contiguous(a);
  std::vector<cst_info> result;
  result.reserve(dims.size());
  for(size_t d = 0; d < dims.size(); d++){
    unsigned num_cst;
    if(a_cst[d].num_cst == 0 && b_cst[d].value && x->is_int_div())
      // todo might not be entirely true
      num_cst = gcd(a_contiguous[d], b_cst[d].value);
    else
      num_cst = std::min(a_cst[d].num_cst, b_cst[d].num_cst);
    result.push_back(cst_info{num_cst, 0});
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy info for a getelementptr: per-axis minimum of the num_cst of the
// pointer operand and the offset operand; the recorded value is always 0.
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
  const std::vector<unsigned> dims = get_shapes(x);
  const std::vector<cst_info> ptr_cst = populate_is_constant(x->get_operand(0));
  const std::vector<cst_info> off_cst = populate_is_constant(x->get_operand(1));
  std::vector<cst_info> result;
  result.reserve(dims.size());
  for(size_t d = 0; d < dims.size(); d++)
    result.push_back(cst_info{std::min(ptr_cst[d].num_cst, off_cst[d].num_cst), 0});
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Fallback constancy info: {num_cst = 1, value = 0} on every axis.
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
  const size_t rank = get_shapes(v).size();
  return add_to_cache(v, std::vector<cst_info>(rank, cst_info{1, 0}), is_constant_);
}
|
||||
|
||||
// Entry point of the constancy analysis: returns the cached info for `v` if
// present, otherwise dispatches on the dynamic type of `v` (the order of the
// checks is preserved from the original) and caches the result.
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
  auto cached = is_constant_.find(v);
  if(cached != is_constant_.end())
    return cached->second;
  // integer literal: single entry, value capped at 128
  if(auto *literal = dynamic_cast<ir::constant_int*>(v)){
    const unsigned capped = std::min<unsigned>(literal->get_value(), 128);
    return add_to_cache(v, {cst_info{1, capped}}, is_constant_);
  }
  if(dynamic_cast<ir::make_range_sta*>(v))
    return add_to_cache(v, {cst_info{1, 0}}, is_constant_);
  if(auto *phi = dynamic_cast<ir::phi_node*>(v))
    return populate_is_constant_phi(phi);
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v))
    return populate_is_constant_splat(splat);
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v))
    return populate_is_constant_reshape(reshape);
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_is_constant_broadcast(broadcast);
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v))
    return populate_is_constant_binop(binop);
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_is_constant_gep(gep);
  return populate_is_constant_default(v);
}
|
||||
|
||||
|
||||
/*
|
||||
* max contiguous
|
||||
*/
|
||||
|
||||
// Max contiguity for a phi node.
// A provisional result (1 per axis, or the cached result of the last
// already-analysed incoming value) is cached first; populate_max_contiguous
// returns cached entries, so a recursive visit that reaches `x` again stops
// there. The final result is the per-axis minimum over all incoming values.
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
  const size_t rank = get_shapes(x).size();
  const unsigned num_incoming = x->get_num_incoming();
  std::vector<unsigned> result(rank, 1);
  for(unsigned n = 0; n < num_incoming; n++){
    auto known = max_contiguous_.find(x->get_incoming_value(n));
    if(known != max_contiguous_.end())
      result = known->second;
  }
  add_to_cache(x, result, max_contiguous_);
  // recurse
  for(unsigned n = 0; n < num_incoming; n++){
    const std::vector<unsigned> incoming = populate_max_contiguous(x->get_incoming_value(n));
    for(size_t d = 0; d < result.size(); d++)
      result[d] = std::min(result[d], incoming[d]);
  }
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
// Max contiguity for a splat: 1 along every axis.
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
  const size_t rank = get_shapes(x).size();
  return add_to_cache(x, std::vector<unsigned>(rank, 1), max_contiguous_);
}
|
||||
|
||||
// Max contiguity for a reshape.
// Result axes of extent 1 get 1. Other axes are matched left-to-right against
// the operand axes: a matching extent consumes the operand axis and inherits
// its contiguity; the first mismatch marks the reshape as skewed and every
// axis from there on gets 1.
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
  const std::vector<unsigned> dims = get_shapes(x);
  ir::value *src = x->get_operand(0);
  auto src_dims = src->get_type()->get_tile_shapes();
  const std::vector<unsigned> src_contiguous = populate_max_contiguous(src);
  std::vector<unsigned> result;
  result.reserve(dims.size());
  unsigned next_src = 0;
  bool skewed = false;
  for(unsigned dim : dims){
    unsigned contiguous_len = 1;
    if(dim != 1){
      if(!skewed && dim == src_dims[next_src])
        contiguous_len = src_contiguous[next_src++];
      else
        skewed = true;
    }
    result.push_back(contiguous_len);
  }
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
// Max contiguity for a broadcast: 1 on axes where the operand has extent 1,
// the operand's contiguity everywhere else.
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
  const size_t rank = get_shapes(x).size();
  ir::value *src = x->get_operand(0);
  auto src_dims = src->get_type()->get_tile_shapes();
  const std::vector<unsigned> src_contiguous = populate_max_contiguous(src);
  std::vector<unsigned> result;
  result.reserve(rank);
  for(size_t d = 0; d < rank; d++)
    result.push_back(src_dims[d] == 1 ? 1 : src_contiguous[d]);
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
// Max contiguity for a binary operator, per axis (1 unless a rule applies):
//  - integer remainder with a positive recorded rhs value: min of the lhs
//    contiguity and that value;
//  - integer multiply: the contiguity of one operand when the other operand's
//    recorded value is 1 (max of both candidates);
//  - integer add/sub: gcd of one operand's contiguity with the other operand's
//    num_cst when that num_cst is positive (max of both candidates).
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
  const size_t rank = get_shapes(x).size();
  ir::value *a = x->get_operand(0);
  ir::value *b = x->get_operand(1);
  const std::vector<unsigned> a_contiguous = populate_max_contiguous(a);
  const std::vector<unsigned> b_contiguous = populate_max_contiguous(b);
  const std::vector<cst_info> a_cst = populate_is_constant(a);
  const std::vector<cst_info> b_cst = populate_is_constant(b);
  std::vector<unsigned> result;
  result.reserve(rank);
  for(size_t d = 0; d < rank; d++){
    unsigned value = 1;
    if(x->is_int_rem() && b_cst[d].value > 0)
      value = std::min(a_contiguous[d], b_cst[d].value);
    if(x->is_int_mult()){
      const unsigned from_a = b_cst[d].value == 1 ? a_contiguous[d] : 1;
      const unsigned from_b = a_cst[d].value == 1 ? b_contiguous[d] : 1;
      value = std::max(from_a, from_b);
    }
    if(x->is_int_add_sub()){
      const unsigned from_b = a_cst[d].num_cst > 0 ? gcd(b_contiguous[d], a_cst[d].num_cst) : 1;
      const unsigned from_a = b_cst[d].num_cst > 0 ? gcd(a_contiguous[d], b_cst[d].num_cst) : 1;
      value = std::max(from_b, from_a);
    }
    result.push_back(value);
  }
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
// Max contiguity for a getelementptr: along each axis, an operand with a
// non-zero num_cst lets the other operand's contiguity through; the result is
// the larger of the two candidates (1 when neither applies).
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
  const size_t rank = get_shapes(x).size();
  ir::value *ptr = x->get_operand(0);
  ir::value *off = x->get_operand(1);
  const std::vector<unsigned> ptr_contiguous = populate_max_contiguous(ptr);
  const std::vector<unsigned> off_contiguous = populate_max_contiguous(off);
  const std::vector<cst_info> ptr_cst = populate_is_constant(ptr);
  const std::vector<cst_info> off_cst = populate_is_constant(off);
  std::vector<unsigned> result(rank, 1);
  for(size_t d = 0; d < rank; d++){
    const unsigned from_off = ptr_cst[d].num_cst ? off_contiguous[d] : 1;
    const unsigned from_ptr = off_cst[d].num_cst ? ptr_contiguous[d] : 1;
    result[d] = std::max(from_off, from_ptr);
  }
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
// Fallback max contiguity:
//  - non-tile values: {1};
//  - make_range / make_range_sta: a single entry equal to the first extent;
//  - any other tile: 1 along every axis.
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
  ir::type *ty = v->get_type();
  if(!ty->is_tile_ty())
    return add_to_cache(v, std::vector<unsigned>(1, 1), max_contiguous_);
  const auto shapes = ty->get_tile_shapes();
  const bool is_range = dynamic_cast<ir::make_range*>(v) != nullptr
                     || dynamic_cast<ir::make_range_sta*>(v) != nullptr;
  if(is_range)
    return add_to_cache(v, std::vector<unsigned>(1, shapes[0]), max_contiguous_);
  return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
|
||||
|
||||
// Entry point of the contiguity analysis: returns the cached result for `v`
// if present, otherwise dispatches on the dynamic type of `v` (check order
// preserved from the original).
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
  auto cached = max_contiguous_.find(v);
  if(cached != max_contiguous_.end())
    return cached->second;
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v))
    return populate_max_contiguous_splat(splat);
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v))
    return populate_max_contiguous_reshape(reshape);
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_max_contiguous_broadcast(broadcast);
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v))
    return populate_max_contiguous_binop(binop);
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_max_contiguous_gep(gep);
  if(auto *phi = dynamic_cast<ir::phi_node*>(v))
    return populate_max_contiguous_phi(phi);
  return populate_max_contiguous_default(v);
}
|
||||
|
||||
|
||||
/*
|
||||
* starting multiple
|
||||
*/
|
||||
|
||||
// Starting multiple for a splat: the first entry of the operand's result,
// replicated along every axis.
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
  const size_t rank = get_shapes(x).size();
  const std::vector<unsigned> src = populate_starting_multiple(x->get_operand(0));
  return add_to_cache(x, std::vector<unsigned>(rank, src[0]), starting_multiple_);
}
|
||||
|
||||
// Starting multiple for a reshape.
// Axes of extent 1 get 1. Other axes are matched left-to-right against the
// operand axes: a matching extent consumes the operand axis and inherits its
// multiple; after the first mismatch (skew) every axis gets 1.
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
  ir::value *src = x->get_operand(0);
  const std::vector<unsigned> src_multiple = populate_starting_multiple(src);
  const std::vector<unsigned> src_dims = get_shapes(src);
  const std::vector<unsigned> dims = get_shapes(x);
  std::vector<unsigned> result(dims.size(), 1);
  unsigned next_src = 0;
  bool skewed = false;
  for(size_t d = 0; d < dims.size(); d++){
    if(dims[d] == 1)
      continue;             // result[d] is already 1
    if(!skewed && dims[d] == src_dims[next_src])
      result[d] = src_multiple[next_src++];
    else
      skewed = true;        // result[d] stays 1
  }
  return add_to_cache(x, result, starting_multiple_);
}
|
||||
|
||||
// Starting multiple for a broadcast: identical to that of its operand.
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
  const std::vector<unsigned> src = populate_starting_multiple(x->get_operand(0));
  return add_to_cache(x, src, starting_multiple_);
}
|
||||
|
||||
// Starting multiple for a binary operator, per axis (1 unless a rule applies):
//  - multiply:  product of the operands' multiples
//  - add/sub:   gcd of the operands' multiples
//  - divide:    lhs / rhs, at least 1
//  - remainder: gcd of the multiples when rhs > 1
//  - shl / shr: lhs shifted by rhs (shr result is at least 1)
// A multiple of 1 is kept whenever the computation itself would be undefined.
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  // shifting an `unsigned` by its bit width or more is undefined behaviour
  const unsigned num_bits = 8 * sizeof(unsigned);
  std::vector<unsigned> result(lhs.size(), 1);
  for(size_t d = 0; d < lhs.size(); d++){
    if(x->is_int_mult())
      result[d] = lhs[d] * rhs[d];
    if(x->is_int_add_sub())
      result[d] = gcd(lhs[d], rhs[d]);
    // rhs[d] is 0 for a constant 0 and for a range starting at 0 (see
    // populate_starting_multiple); dividing by it would crash the analysis
    if(x->is_int_div() && rhs[d] != 0)
      result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
    if(x->is_int_rem() && rhs[d] > 1)
      result[d] = gcd(lhs[d], rhs[d]);
    if(x->is_shl() && rhs[d] < num_bits)
      result[d] = lhs[d] << rhs[d];
    if(x->is_shr() && rhs[d] < num_bits)
      result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
  }
  return add_to_cache(x, result, starting_multiple_);
}
|
||||
|
||||
// Starting multiple for a getelementptr: per-axis gcd of the multiples of the
// pointer operand and the offset operand.
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
  const std::vector<unsigned> ptr = populate_starting_multiple(x->get_operand(0));
  const std::vector<unsigned> off = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(ptr.size(), 1);
  size_t d = 0;
  while(d < ptr.size()){
    result[d] = gcd(ptr[d], off[d]);
    ++d;
  }
  return add_to_cache(x, result, starting_multiple_);
}
|
||||
|
||||
// Starting multiple for a phi node.
// A provisional result (1 per axis, or the cached result of the last
// already-analysed incoming value) is cached first; populate_starting_multiple
// returns cached entries, so a recursive visit that reaches `x` again stops
// there. The final result is the per-axis gcd over all incoming values.
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
  const size_t rank = get_shapes(x).size();
  const unsigned num_incoming = x->get_num_incoming();
  std::vector<unsigned> result(rank, 1);
  for(unsigned n = 0; n < num_incoming; n++){
    auto known = starting_multiple_.find(x->get_incoming_value(n));
    if(known != starting_multiple_.end())
      result = known->second;
  }
  add_to_cache(x, result, starting_multiple_);
  // recurse
  for(unsigned n = 0; n < num_incoming; n++){
    const std::vector<unsigned> incoming = populate_starting_multiple(x->get_incoming_value(n));
    for(size_t d = 0; d < result.size(); d++)
      result[d] = gcd(result[d], incoming[d]);
  }
  return add_to_cache(x, result, starting_multiple_);
}
|
||||
|
||||
|
||||
// Fallback starting multiple:
//  - tile values: the tile shape itself;
//  - instructions carrying a positive `multiple_of` metadata: that value;
//  - arguments: the first `multiple_of` attribute value, or the first
//    `aligned` attribute value divided by the pointee size in bytes;
//  - everything else: {1}.
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
  ir::type* ty = v->get_type();
  if(ty->is_tile_ty())
    return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(multiple_of > 0)
      return add_to_cache(x, {multiple_of}, starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::argument*>(v)){
    std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
    for(auto attr: attributes){
      if(attr.get_kind() == ir::multiple_of)
        return add_to_cache(x, {attr.get_value()}, starting_multiple_);
      if(attr.get_kind() == ir::aligned){
        ir::type* elt_ty = x->get_type()->get_pointer_element_ty();
        int nbits = elt_ty->get_primitive_size_in_bits();
        // element types narrower than 8 bits would give nbytes == 0 and a
        // division by zero below; treat them as one byte wide
        // NOTE(review): confirm that is the intended unit for sub-byte types
        int nbytes = std::max(nbits / 8, 1);
        return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
      }
    }
  }
  return add_to_cache(v, {1}, starting_multiple_);
}
|
||||
|
||||
// Entry point of the starting-multiple analysis: returns the cached result
// for `v` if present, otherwise dispatches on the dynamic type of `v` (check
// order preserved from the original).
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
  auto cached = starting_multiple_.find(v);
  if(cached != starting_multiple_.end())
    return cached->second;
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v))
    return populate_starting_multiple_binop(binop);
  // integer literal: its value, capped at 128
  if(auto *literal = dynamic_cast<ir::constant_int*>(v)){
    const unsigned capped = std::min<unsigned>(literal->get_value(), 128);
    return add_to_cache(v, {capped}, starting_multiple_);
  }
  // ranges: the first value of the range
  if(auto *range = dynamic_cast<ir::make_range*>(v)){
    const unsigned first = static_cast<unsigned>(range->get_first()->get_value());
    return add_to_cache(v, {first}, starting_multiple_);
  }
  if(dynamic_cast<ir::make_range_dyn*>(v))
    return add_to_cache(v, {128}, starting_multiple_);
  if(auto *range = dynamic_cast<ir::make_range_sta*>(v)){
    const unsigned first = static_cast<unsigned>(range->get_range()->get_first()->get_value());
    return add_to_cache(v, {first}, starting_multiple_);
  }
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_starting_multiple_gep(gep);
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v))
    return populate_starting_multiple_splat(splat);
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v))
    return populate_starting_multiple_reshape(reshape);
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_starting_multiple_broadcast(broadcast);
  if(auto *phi = dynamic_cast<ir::phi_node*>(v))
    return populate_starting_multiple_phi(phi);
  return populate_starting_multiple_default(v);
}
|
||||
|
||||
|
||||
// Alignment of `v` along axis `ax`: the smaller of its starting multiple and
// its max contiguity. Throws std::out_of_range (via map::at) if `v` has not
// been analysed.
unsigned align::get(ir::value *v, unsigned ax) const {
  const unsigned multiple = starting_multiple_.at(v)[ax];
  const unsigned contiguous_len = max_contiguous_.at(v)[ax];
  return contiguous_len < multiple ? contiguous_len : multiple;
}
|
||||
|
||||
// Per-axis max contiguity recorded for `v`. Throws std::out_of_range (via
// map::at) if `v` has not been analysed.
std::vector<unsigned> align::contiguous(ir::value* v) const {
  const std::vector<unsigned> &recorded = max_contiguous_.at(v);
  return recorded;
}
|
||||
|
||||
|
||||
// Runs the three analyses on `v` in a fixed order: constancy, starting
// multiple, then max contiguity. The results are kept in the caches; the
// returned vectors are not needed here.
void align::populate(ir::value *v) {
  (void)populate_is_constant(v);
  (void)populate_starting_multiple(v);
  (void)populate_max_contiguous(v);
}
|
||||
|
||||
// Analyses every value that ir::for_each_value visits in `mod`.
void align::run(ir::module &mod) {
  auto analyse = [this](ir::value *v) { populate(v); };
  ir::for_each_value(mod, analyse);
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
107
lib/codegen/analysis/allocation.cc
Normal file
107
lib/codegen/analysis/allocation.cc
Normal file
@@ -0,0 +1,107 @@
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
// Assigns an offset to every shared layout reported by the liveness analysis.
//  1. Greedy placement: repeatedly take the lowest (offset, live-range)
//     triple from H and give that offset to the first pending layout whose
//     live range intersects the triple's range and none of the other triples.
//  2. Build an interference graph between layouts that overlap both in live
//     range and in the address range chosen in step 1.
//  3. First-fit colour that graph and shift each layout by
//     colour * (highest end address among its interfering neighbours).
//  4. Record the final offsets in offsets_ and the total size in
//     allocated_size_.
void allocation::run(ir::module &mod) {
  using std::max;
  using std::min;
  typedef std::multimap<unsigned, segment> triples_map_type;

  // layouts that still need a start offset
  std::vector<shared_layout*> J;
  for(auto x: liveness_->get())
    J.push_back(x.first);

  triples_map_type H;
  H.insert({0, segment{0, INT_MAX}});

  std::vector<shared_layout*> V;
  std::map<shared_layout*, unsigned> starts;
  // NOTE(review): H.begin() is dereferenced without checking that H is
  // non-empty; H shrinks on every iteration in which no pending layout fits.
  while(!J.empty()){
    auto h_it = H.begin();
    unsigned w = h_it->first;
    segment xh = h_it->second;
    H.erase(h_it);
    auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
      segment xj = liveness_->get(JJ);
      bool res = xj.intersect(xh);
      for(auto val: H)
        res = res && !val.second.intersect(xj);
      return res;
    });
    if(j_it != J.end()){
      unsigned size = (*j_it)->get_size();
      segment xj = liveness_->get(*j_it);
      starts[*j_it] = w;
      H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
      if(xh.start < xj.start)
        H.insert({w, segment{xh.start, xj.end}});
      if(xj.end < xh.end)
        H.insert({w, segment{xj.start, xh.end}});
      V.push_back(*j_it);
      J.erase(j_it);
    }
  }

  // Build interference graph
  std::map<shared_layout*, std::set<shared_layout*>> interferences;
  for(shared_layout* x: V)
  for(shared_layout* y: V){
    if(x == y)
      continue;
    unsigned X0 = starts[x], Y0 = starts[y];
    unsigned NX = x->get_size();
    unsigned NY = y->get_size();
    segment XS = {X0, X0 + NX};
    segment YS = {Y0, Y0 + NY};
    if(liveness_->get(x).intersect(liveness_->get(y))
        && XS.intersect(YS))
      interferences[x].insert(y);
  }

  // Initialize colors
  std::map<shared_layout*, int> colors;
  for(shared_layout* X: V)
    colors[X] = (X==V[0])?0:-1;

  // First-fit graph coloring
  std::vector<bool> available(V.size());
  for(shared_layout* x: V){
    // Non-neighboring colors are available
    std::fill(available.begin(), available.end(), true);
    for(shared_layout* Y: interferences[x]){
      int color = colors[Y];
      if(color >= 0)
        available[color] = false;
    }
    // Assigns first available color
    auto It = std::find(available.begin(), available.end(), true);
    colors[x] = std::distance(available.begin(), It);
  }

  // Finalize allocation and save maximum size of induced memory space.
  // The size is measured from the final offset: a layout with a non-zero
  // colour sits above starts[x], so measuring from starts[x] (as the original
  // did) could report less memory than the offsets actually use.
  allocated_size_ = 0;
  for(shared_layout* x: V){
    unsigned Adj = 0;
    for(shared_layout* y: interferences[x])
      Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
    offsets_[x] = starts[x] + colors[x] * Adj;
    allocated_size_ = std::max<size_t>(allocated_size_, offsets_[x] + x->get_size());
  }
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
147
lib/codegen/analysis/axes.cc
Normal file
147
lib/codegen/analysis/axes.cc
Normal file
@@ -0,0 +1,147 @@
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
// Nothing to set up: members are default-initialised.
axes::axes() = default;
|
||||
|
||||
// Reduction: every operand axis except the reduced one is linked to the next
// free axis of the result, in order.
void axes::update_graph_reduce(ir::instruction *i) {
  auto* red = static_cast<ir::reduce_inst*>(i);
  const unsigned reduced_axis = red->get_axis();
  ir::value *src = red->get_operand(0);
  auto src_shapes = src->get_type()->get_tile_shapes();
  unsigned out_axis = 0;
  for(unsigned d = 0; d < src_shapes.size(); d++)
    if(d != reduced_axis)
      graph_.add_edge({i, out_axis++}, {src, d});
}
|
||||
|
||||
// Reshape: result axes are matched left-to-right against operand axes. A
// result axis whose extent equals the current operand axis is linked to it
// (and consumes it); any other result axis only gets a self-edge, i.e. a
// fresh node. After the first mismatch on a non-unit axis the reshape is
// considered skewed and no further axes are linked.
void axes::update_graph_reshape(ir::instruction *i) {
  auto* reshape = static_cast<ir::reshape_inst*>(i);
  // operands
  ir::value *op = reshape->get_operand(0);
  // shapes
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto res_shapes = reshape->get_type()->get_tile_shapes();
  // construct edges
  unsigned current = 0;
  bool is_skewed = false;
  for(unsigned d = 0; d < res_shapes.size(); d ++){
    // Once every operand axis has been consumed (e.g. [N] -> [N, 1]) there is
    // nothing left to compare against; reading op_shapes[current] would be
    // out of bounds.
    bool same_shape = current < op_shapes.size()
                      && res_shapes[d] == op_shapes[current];
    // either add edge between axis or just add a node in the graph
    if(!is_skewed && same_shape)
      graph_.add_edge({i, d}, {op, current++});
    else
      graph_.add_edge({i, d}, {i, d});
    // reshaping is skewed
    if(res_shapes[d] > 1 && !same_shape)
      is_skewed = true;
  }
}
|
||||
|
||||
// Transposition: result axis perm[d] is linked to operand axis d.
void axes::update_graph_trans(ir::instruction *i) {
  auto *trans = static_cast<ir::trans_inst*>(i);
  ir::value *src = trans->get_operand(0);
  auto perm = trans->get_perm();
  unsigned d = 0;
  while(d < perm.size()){
    graph_.add_edge({i, perm[d]}, {src, d});
    ++d;
  }
}
|
||||
|
||||
// Broadcast: only axes whose extent is unchanged between operand and result
// are linked; broadcast axes get no edge.
void axes::update_graph_broadcast(ir::instruction *i) {
  auto *broadcast = static_cast<ir::broadcast_inst*>(i);
  auto res_shapes = broadcast->get_type()->get_tile_shapes();
  ir::value *src = broadcast->get_operand(0);
  const auto& src_shapes = src->get_type()->get_tile_shapes();
  for(unsigned d = 0; d < res_shapes.size(); d++){
    if(src_shapes[d] != res_shapes[d])
      continue;
    graph_.add_edge({i, d}, {src, d});
  }
}
|
||||
|
||||
// Dot product: each result axis is linked to the same axis of the
// accumulator (operand 2). The two multiplied operands contribute no edges,
// so they are not fetched.
void axes::update_graph_dot(ir::instruction *i) {
  auto *dot = static_cast<ir::dot_inst*>(i);
  auto shapes = dot->get_type()->get_tile_shapes();
  ir::value *D = dot->get_operand(2);
  // add edges between result and accumulator
  for(unsigned d = 0; d < shapes.size(); d++)
    graph_.add_edge({dot, d}, {D, d});
}
|
||||
|
||||
// Element-wise instruction: for every axis, all operands are linked to each
// other and, unless the instruction is void-typed, to the result. Nothing is
// added when there are no operands or the first operand is not a tile.
void axes::update_graph_elementwise(ir::instruction *i) {
  if(i->get_num_operands() == 0)
    return;
  ir::type *first_ty = i->get_operand(0)->get_type();
  if(!first_ty->is_tile_ty())
    return;
  const auto rank = first_ty->get_tile_rank();
  const bool has_result = !i->get_type()->is_void_ty();
  for(unsigned d = 0; d < rank; d++){
    for(ir::value* opx: i->ops()){
      for(ir::value* opy: i->ops()){
        if(has_result)
          graph_.add_edge({i, d}, {opx, d});
        graph_.add_edge({opx, d}, {opy, d});
      }
    }
  }
}
|
||||
|
||||
// Registers every axis of a tile-typed instruction as its own node (a
// self-edge) without linking it to any operand.
void axes::update_graph_no_edge(ir::instruction *i) {
  ir::type *ty = i->get_type();
  if(ty->is_tile_ty()){
    const auto rank = ty->get_tile_rank();
    for(unsigned d = 0; d < rank; d++)
      graph_.add_edge({i, d}, {i, d});
  }
}
|
||||
|
||||
// Dispatches on the instruction id to the matching update_graph_* handler;
// instructions without a dedicated handler are treated as element-wise.
void axes::update_graph(ir::instruction *i) {
  switch (i->get_id()) {
    case ir::INST_REDUCE:
      update_graph_reduce(i);
      break;
    case ir::INST_RESHAPE:
      update_graph_reshape(i);
      break;
    case ir::INST_TRANS:
      update_graph_trans(i);
      break;
    case ir::INST_BROADCAST:
      update_graph_broadcast(i);
      break;
    case ir::INST_DOT:
      update_graph_dot(i);
      break;
    // these only introduce fresh axes
    case ir::INST_SPLAT:
    case ir::INST_COPY_TO_SHARED:
    case ir::INST_COPY_FROM_SHARED:
    case ir::INST_RECOALESCE:
      update_graph_no_edge(i);
      break;
    default:
      update_graph_elementwise(i);
      break;
  }
}
|
||||
|
||||
|
||||
// Axis id assigned to dimension `dim` of `value`. Throws std::out_of_range
// (via map::at) when that (value, dim) pair has no entry.
int axes::get(ir::value *value, unsigned dim) {
  const auto &axis_id = axes_.at({value, dim});
  return axis_id;
}
|
||||
|
||||
// Axis ids of all dimensions of the tile-typed `value`, in dimension order.
std::vector<int> axes::get(ir::value *value) {
  const size_t rank = value->get_type()->get_tile_rank();
  std::vector<int> result;
  result.reserve(rank);
  for(size_t d = 0; d < rank; d++)
    result.push_back(get(value, d));
  return result;
}
|
||||
|
||||
// Rebuilds the axis graph from every instruction of `mod`, then labels its
// connected components into axes_.
void axes::run(ir::module &mod) {
  // make graph
  graph_.clear();
  auto add_instruction = [this](ir::instruction *inst) { update_graph(inst); };
  ir::for_each_instruction(mod, add_instruction);
  // find connected components
  graph_.connected_components(nullptr, &axes_);
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
442
lib/codegen/analysis/layout.cc
Normal file
442
lib/codegen/analysis/layout.cc
Normal file
@@ -0,0 +1,442 @@
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
/* -------------------------------- *
|
||||
* Helper Functions *
|
||||
* -------------------------------- */
|
||||
|
||||
// Raises x to at least `lo`, then caps the result at `hi`.
// When hi < lo the cap wins and `hi` is returned.
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
  const unsigned raised = x < lo ? lo : x;
  return hi < raised ? hi : raised;
}
|
||||
|
||||
// True when `v` is a dot instruction whose first two operands both have a
// half-precision scalar type.
inline bool is_hmma_c(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(!dot)
    return false;
  ir::type *a_ty = dot->get_operand(0)->get_type();
  ir::type *b_ty = dot->get_operand(1)->get_type();
  return a_ty->get_scalar_ty()->is_half_ty()
      && b_ty->get_scalar_ty()->is_half_ty();
}
|
||||
|
||||
// Adds `v` to `result` if some user of `v` is an io instruction that uses `v`
// as its pointer operand.
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
  for(ir::user* u: v->get_users()){
    auto *io = dynamic_cast<ir::io_inst*>(u);
    if(io == nullptr || io->get_pointer_operand() != v)
      continue;
    result.insert(v);
  }
}
|
||||
|
||||
// Sets `result` to `v` if some user of `v` is a dot instruction that takes
// `v` as its operand number `n`; leaves `result` untouched otherwise.
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto *dot = dynamic_cast<ir::dot_inst*>(u);
    if(dot == nullptr || dot->get_operand(n) != v)
      continue;
    result = v;
  }
}
|
||||
|
||||
// Same as extract_dot_use, but only dot instructions accepted by is_hmma_c
// count.
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto *dot = dynamic_cast<ir::dot_inst*>(u);
    if(dot == nullptr || !is_hmma_c(dot) || dot->get_operand(n) != v)
      continue;
    result = v;
  }
}
|
||||
|
||||
|
||||
// True when `v` is a transposition, or an instruction all of whose operands
// are (recursively) transpositions. An instruction without operands therefore
// yields true; non-instruction values yield false.
inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v))
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(inst == nullptr)
    return false;
  // stop at the first operand that is not a transposition
  for(ir::value *op: inst->ops())
    if(!is_trans(op))
      return false;
  return true;
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Layout Visitor *
|
||||
* -------------------------------- */
|
||||
|
||||
// Double dispatch: hands this visitor to the layout, which calls back the
// visit method matching its concrete type.
void layout_visitor::visit_layout(data_layout *layout) {
  data_layout &target = *layout;
  target.accept(this);
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Base Data Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
// Builds the common part of a layout and derives the axis order.
// order_ starts as 0, 1, ..., rank-1. If at least one of `values` is used as
// the pointer operand of an io instruction, the one with the highest tile
// rank is picked and order_ is sorted by decreasing max contiguity of that
// value.
data_layout::data_layout(id_t id,
                         const std::vector<int> &axes,
                         const std::vector<unsigned> &shape,
                         const std::vector<ir::value *> &values,
                         analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
  // io pointer
  std::set<ir::value*> ptr;
  for(ir::value* v: values_)
    extract_io_use(v, ptr);
  order_.resize(axes_.size());
  std::iota(order_.begin(), order_.end(), 0);
  auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
    return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank();
  });
  // max_element returns ptr.end() when no io pointer was found; end() must
  // not be dereferenced, so test the iterator before the pointee
  if(largest != ptr.end() && *largest){
    auto max_contiguous = align->contiguous(*largest);
    std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
      return max_contiguous[a] > max_contiguous[b];
    });
  }
}
|
||||
|
||||
size_t data_layout::find_axis(int to_find) const {
|
||||
auto it = std::find(axes_.begin(), axes_.end(), to_find);
|
||||
return std::distance(axes_.begin(), it);
|
||||
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* MMA Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
// Builds an HMMA_884 layout: picks the number of fragments per warp (fpw_)
// and warps per tile (wpt_) along the first two dimensions, doubling each in
// turn until the target is reached or nothing changes anymore.
// Throws std::runtime_error when the resulting warp count differs from
// `num_warps`.
mma884_layout::mma884_layout(size_t num_warps,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  fpw_ = {1, 1, 1};
  unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
  for(;;) {
    std::vector<int> previous;
    previous = fpw_;
    if(fpw_[0]*fpw_[1] < num_fragments)
      fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
    if(fpw_[0]*fpw_[1] < num_fragments)
      fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
    // fixed point reached
    if(!(previous != fpw_))
      break;
  }

  /* warps per tile */
  // try to make things as square as possible to maximize data re-use
  wpt_ = {1, 1, 1};
  for(;;) {
    std::vector<int> previous;
    previous = wpt_;
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
    if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
    // fixed point reached
    if(!(previous != wpt_))
      break;
  }

  /* sanity check */
  unsigned warps_used = 1;
  for(size_t dim = 0; dim < shape.size(); dim++)
    warps_used *= wpt_[dim];
  if(num_warps != warps_used)
    throw std::runtime_error("cannot create a kernel with this amount of warps");
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Scanline Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
// Builds a SCANLINE layout: for every dimension computes nts_ and mts_
// (judging by how they are combined below, elements per thread and threads
// along that dimension -- TODO confirm against the header), starting with the
// first axis of order_.
// Throws std::runtime_error when the product of mts_ differs from
// num_warps * 32.
scanline_layout::scanline_layout(size_t num_warps,
                                 const std::vector<int>& axes,
                                 const std::vector<unsigned>& shape,
                                 const std::vector<ir::value *> &values,
                                 analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
  unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
  unsigned num_threads = num_warps * 32;
  nts_.resize(shape_.size());
  mts_.resize(shape_.size());
  bool is_dot = std::any_of(values.begin(), values.end(),
                            [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v) != nullptr; });

  // pointer operand of a store that consumes one of the values
  // (when there are several, the last one visited wins)
  ir::value *ptr = nullptr;
  for(ir::value *v: values)
    for(ir::user *usr: v->get_users()){
      auto *st = dynamic_cast<ir::store_inst*>(usr);
      if(st)
        ptr = st->get_pointer_operand();
    }

  // first axis in order_: at most 4 elements per thread, fewer when the
  // stored-to pointer is less contiguous
  unsigned i = order_[0];
  int contiguous = ptr ? std::min<int>(align->contiguous(ptr)[i], 4) : 4;

  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
  size /= shape_[i];
  num_threads /= mts_[i];
  // layouts holding a dot also get several elements per thread on the
  // second axis of order_
  if(is_dot)
    nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
  // remaining axes
  for(size_t d = 1; d < shape_.size(); d++){
    i = order_[d];
    const bool keep_dot_nts = (d == 1) && is_dot;
    if(!keep_dot_nts)
      nts_[i] = 1;
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    num_threads = num_threads / mts_[i];
  }
  /* sanity check */
  unsigned threads_used = 1;
  for(size_t d = 0; d < shape_.size(); d++)
    threads_used *= mts_[d];

  if(num_warps * 32 != threads_used)
    throw std::runtime_error("cannot create a kernel with this amount of warps");
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* Shared Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
// Returns true when `terminator` sits in the same basic block as `phi` and is
// a conditional branch with that block as one of its destinations.
// Unconditional branches yield false; any other terminator kind throws
// std::runtime_error.
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  auto *phi_block = phi->get_parent();
  if(phi_block != terminator->get_parent())
    return false;
  if(auto *cond = dynamic_cast<ir::cond_branch_inst*>(terminator)){
    if(cond->get_true_dest() == phi_block)
      return true;
    return cond->get_false_dest() == phi_block;
  }
  if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  throw std::runtime_error("unreachable");
}
|
||||
|
||||
|
||||
// If `v` is a two-input phi node whose incoming values are both
// copy_to_shared instructions, and one incoming block ends with a loop-latch
// branch (see is_loop_latch), stores the {first, latch, phi} triple in `res`.
// `res` is left untouched otherwise.
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  if(phi == nullptr)
    return;
  if(phi->get_num_incoming() != 2)
    return;
  // latch detection comes first: is_loop_latch may throw on unexpected terminators
  ir::instruction *term_0 = phi->get_incoming_block(0)->get_inst_list().back();
  ir::instruction *term_1 = phi->get_incoming_block(1)->get_inst_list().back();
  const bool latch_0 = is_loop_latch(phi, term_0);
  const bool latch_1 = is_loop_latch(phi, term_1);
  ir::value *in_0 = phi->get_incoming_value(0);
  ir::value *in_1 = phi->get_incoming_value(1);
  // both incoming values must be copies to shared memory
  const bool both_cts = dynamic_cast<ir::copy_to_shared_inst*>(in_0) != nullptr
                     && dynamic_cast<ir::copy_to_shared_inst*>(in_1) != nullptr;
  if(!both_cts)
    return;
  if(latch_1)
    res.reset(new double_buffer_info_t{in_0, in_1, phi});
  if(latch_0)
    res.reset(new double_buffer_info_t{in_1, in_0, phi});
}
|
||||
|
||||
|
||||
// Builds a SHARED layout.
//  - arg: layout whose order is inherited (order {0} when null).
//  - ty:  element type, used to compute the size in bytes.
// The constructor detects double-buffering, overrides the order for operands
// of non-HMMA dot instructions, pads the leading dimension of shape_, and
// finally computes size_ (doubled when double-buffered).
shared_layout::shared_layout(const data_layout *arg,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             ir::type *ty,
                             analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {

  size_ = 0;

  // double-buffering
  for(ir::value *v: values)
    extract_double_bufferable(v, double_buffer_);

  // order
  std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
  order_ = arg_order;

  // which values feed operand 0 (a) / operand 1 (b) of a dot, HMMA or not
  ir::value* dot_a = nullptr;
  ir::value* dot_b = nullptr;
  ir::value* hmma_dot_a = nullptr;
  ir::value* hmma_dot_b = nullptr;
  for(ir::value* v: values){
    extract_dot_use(v, dot_a, 0);
    extract_dot_use(v, dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, 0);
    extract_hmma_dot_use(v, hmma_dot_b, 1);
  }

  // non-mma ordering
  std::vector<int> col_major = {0, 1};
  std::vector<int> row_major = {1, 0};
  for(size_t s = 2; s < get_rank(); s++){
    col_major.push_back(s);
    row_major.push_back(s);
  }
  if(dot_a && !hmma_dot_a)
    order_ = is_trans(dot_a) ? row_major : col_major;
  else if(dot_b && !hmma_dot_b)
    order_ = is_trans(dot_b) ? col_major : row_major;

  // padding
  size_t pad = 0;
  if(hmma_dot_a){
    // exclusive-or of "is transposed" and "leading axis is not 0"
    const bool is_row = is_trans(hmma_dot_a) != (order_[0] != 0);
    pad = 24 - shape_[is_row ? 0 : 1] % 32;
  }
  else if(hmma_dot_b){
    const bool is_row = is_trans(hmma_dot_b) != (order_[0] != 0);
    pad = 24 - shape_[is_row ? 1 : 0] % 32;
  }
  else if(order_ != arg_order) {
    pad = 4;
  }
  shape_[order_[0]] += pad;

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
  for(auto extent: shape_)
    size_ *= extent;
  if(double_buffer_)
    size_ *= 2;
}
|
||||
|
||||
|
||||
/* -------------------------------- *
|
||||
* ---- Layouts Inference Pass ---- *
|
||||
* -------------------------------- */
|
||||
|
||||
// Stores the axes analysis, the alignment analysis and the number of warps
// that run() and create() rely on.
layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
  : axes_(axes), align_(align), num_warps_(num_warps) { }
|
||||
|
||||
|
||||
// Registers two tile values in the layout graph. Each one gets a self-edge,
// and an edge is added between them when they share at least one axis.
// Does nothing when x == y or when either value is not a tile.
void layouts::connect(ir::value *x, ir::value *y) {
  if(x == y)
    return;
  if(!x->get_type()->is_tile_ty() || !y->get_type()->is_tile_ty())
    return;
  std::vector<int> x_axes = axes_->get(x);
  std::vector<int> y_axes = axes_->get(y);
  // do the two axis lists intersect?
  const std::set<int> y_lookup(y_axes.begin(), y_axes.end());
  const bool share_axis = std::any_of(x_axes.begin(), x_axes.end(),
                                      [&](int ax){ return y_lookup.count(ax) != 0; });
  graph_.add_edge(x, x);
  graph_.add_edge(y, y);
  if(share_axis)
    graph_.add_edge(x, y);
}
|
||||
|
||||
// Connects instruction `i` with each of its operands, and every operand with
// every other operand (see connect for the conditions under which an edge is
// really added).
void layouts::make_graph(ir::instruction *i) {
  const auto &operands = i->ops();
  for(ir::value *lhs: operands)
    for(ir::value *rhs: operands){
      connect(i, lhs);
      connect(lhs, rhs);
    }
}
|
||||
|
||||
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||
auto cmp = [](ir::value* x, ir::value *y) {
|
||||
return x->get_type()->get_tile_ranks1() <
|
||||
y->get_type()->get_tile_ranks1();
|
||||
};
|
||||
std::vector<ir::value*> lvalue = values;
|
||||
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
|
||||
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_tile_shapes();
|
||||
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
|
||||
return dynamic_cast<ir::copy_to_shared_inst*>(v);
|
||||
});
|
||||
// type
|
||||
if(it_hmma_c != values.end())
|
||||
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
|
||||
else if(it_cts != values.end()){
|
||||
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
create(groups_.at(arg), values_.at(groups_.at(arg)));
|
||||
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
|
||||
}
|
||||
else
|
||||
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
|
||||
}
|
||||
|
||||
// Runs layout inference on `mod`:
//  1. builds the value graph (make_graph) and its connected components;
//  2. creates one layout per component;
//  3. creates extra shared layouts, recorded in tmp_, for reduce_inst,
//     recoalesce_inst (mma884 input, scanline output) and atomic_cas_inst.
void layouts::run(ir::module &mod) {
  // make graph
  graph_.clear();
  ir::for_each_instruction(mod, [this](ir::instruction* inst) {
    make_graph(inst);
  });

  // connected components
  graph_.connected_components(&values_, &groups_);

  // create layouts
  for(const auto& group: values_)
    create(group.first, group.second);

  // create temporaries; ids continue after the ones used by the components
  size_t id = values_.size();
  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
      id++;
      ir::value *arg = red->get_operand(0);
      unsigned axis = red->get_axis();
      // shape: the reduced axis is divided by the number of elements per thread
      auto shapes = arg->get_type()->get_tile_shapes();
      unsigned extent = shapes[axis];
      scanline_layout *layout = get(arg)->to_scanline();
      unsigned per_thread = layout->nts(axis);
      shapes[axis] = extent / per_thread;
      // create layout
      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
      tmp_[red] = id;
    }
    if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
      ir::value *val = recoalasce->get_operand(0);
      mma884_layout* in_layout = get(val)->to_mma884();
      scanline_layout* out_layout = get(i)->to_scanline();
      // only mma884 -> scanline conversions get a temporary
      if(in_layout == nullptr || out_layout == nullptr)
        return;
      id++;
      ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes();
      ir::type::tile_shapes_t shape(in_shape.size());
      // the leading dimension of the output keeps its full extent
      size_t ld = out_layout->get_order(0);
      shape[ld] = in_shape[ld];
      for(size_t k = 0; k < in_shape.size(); k++){
        if(k == ld)
          continue;
        shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
      }
      // create layout
      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
      tmp_[recoalasce] = id;
    }
    if(auto *atom = dynamic_cast<ir::atomic_cas_inst*>(i)){
      id++;
      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
      tmp_[atom] = id;
    }
  });
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
57
lib/codegen/analysis/liveness.cc
Normal file
57
lib/codegen/analysis/liveness.cc
Normal file
@@ -0,0 +1,57 @@
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
// Computes a live interval for every shared layout.
// Instructions are numbered from 1 in block order (numbering restarts for
// each function). An interval starts at the smallest index among the
// layout's values and ends at the largest index among the users of those
// values; values and users without an index are ignored.
void liveness::run(ir::module &mod) {
  intervals_.clear();

  // Assigns index to each instruction
  std::map<ir::value*, slot_index> indices;
  for(ir::function *fn: mod.get_function_list()){
    slot_index counter = 0;
    for(ir::basic_block *block: fn->blocks())
      for(ir::instruction *instr: block->get_inst_list()){
        counter += 1;
        indices.insert({instr, counter});
      }
  }

  // create live intervals
  for(auto &entry: layouts_->get_all()) {
    shared_layout* layout = entry.second->to_shared();
    if(layout == nullptr)
      continue;
    unsigned start = INT32_MAX;
    unsigned end = 0;
    for(ir::value *v: layout->get_values()){
      // definition point
      auto def = indices.find(v);
      if(def != indices.end())
        start = std::min(start, def->second);
      // use points
      for(ir::user *u: v->get_users()){
        auto use = indices.find(u);
        if(use != indices.end())
          end = std::max(end, use->second);
      }
    }
    intervals_[layout] = segment{start, end};
  }
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
1308
lib/codegen/selection/generator.cc
Normal file
1308
lib/codegen/selection/generator.cc
Normal file
File diff suppressed because it is too large
Load Diff
326
lib/codegen/selection/machine_layout.cc
Normal file
326
lib/codegen/selection/machine_layout.cc
Normal file
@@ -0,0 +1,326 @@
|
||||
#include <numeric>
|
||||
#include "triton/codegen/selection/machine_layout.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
// Converts an ir::type into the matching LLVM type within `ctx`.
// Handles function, pointer and integer types, then the primitive type ids
// listed in the switch. Throws std::runtime_error for anything else.
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
  // function
  if(auto* fn_ty = dynamic_cast<ir::function_type*>(ty)){
    Type *return_ty = llvm_type(fn_ty->get_return_ty(), ctx);
    std::vector<Type*> param_tys;
    for(auto it = fn_ty->params_begin(); it != fn_ty->params_end(); ++it)
      param_tys.push_back(llvm_type(*it, ctx));
    return FunctionType::get(return_ty, param_tys, false);
  }
  // pointer
  if(ty->is_pointer_ty()){
    Type *pointee = llvm_type(ty->get_pointer_element_ty(), ctx);
    return PointerType::get(pointee, ty->get_pointer_address_space());
  }
  // integer
  if(ty->is_integer_ty())
    return IntegerType::get(ctx, ty->get_integer_bitwidth());
  // primitive types
  switch(ty->get_type_id()){
    case ir::type::VoidTyID:      return Type::getVoidTy(ctx);
    case ir::type::HalfTyID:      return Type::getHalfTy(ctx);
    case ir::type::FloatTyID:     return Type::getFloatTy(ctx);
    case ir::type::DoubleTyID:    return Type::getDoubleTy(ctx);
    case ir::type::X86_FP80TyID:  return Type::getX86_FP80Ty(ctx);
    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
    case ir::type::LabelTyID:     return Type::getLabelTy(ctx);
    case ir::type::MetadataTyID:  return Type::getMetadataTy(ctx);
    case ir::type::TokenTyID:     return Type::getTokenTy(ctx);
    default: break;
  }
  // unknown type
  throw std::runtime_error("unknown conversion from ir::type to Type");
}
|
||||
|
||||
// Grid construction
|
||||
// Grid construction
// Splits the linear index `trailing` into one index per dimension.
// Dimensions are peeled off in the order given by `order`: each of the first
// dim-1 ones receives `trailing % shapes[order[k]]` and `trailing` is divided
// accordingly; the last one receives whatever remains.
// Returns an empty vector when `shapes` is empty.
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
  size_t dim = shapes.size();
  std::vector<Value*> result(dim);
  // nothing to do for rank 0; also avoids the unsigned wrap-around of dim - 1
  if(dim == 0)
    return result;
  for(size_t k = 0; k + 1 < dim; k++){
    Constant *dim_k = builder.getInt32(shapes[order[k]]);
    Value *rem = builder.CreateURem(trailing, dim_k);
    trailing = builder.CreateUDiv(trailing, dim_k);
    result[order[k]] = rem;
  }
  result[order[dim - 1]] = trailing;
  return result;
}
|
||||
|
||||
// Integer division of `num` by `div`, rounded up (for non-negative `num` and
// positive `div`).
inline int32_t ceil(int32_t num, int32_t div){
  const int32_t bumped = num + div - 1;
  return bumped / div;
}
|
||||
|
||||
|
||||
|
||||
// Materializes the base pointer(s) of a shared layout.
//  - Not double-buffered: ptr_ is sh_mem_ptr_ offset by the allocation offset
//    of the layout, cast to a pointer to the layout's element type.
//  - Double-buffered: ptr_ and offset_ are PHI nodes inserted in the block of
//    the double-buffer phi; pre_ptr_ is the offset base pointer and next_ptr_
//    is ptr_ advanced by offset_. The insertion point is restored afterwards.
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
                                             std::map<ir::value *, Value *>& vmap,
                                             std::map<ir::value *, tile *>& tmap)
  : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {

  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());

  // single-buffered
  if(!layout_->get_double_buffer()) {
    size_t offset = alloc_->offset(layout_);
    ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
    ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
    return;
  }

  // double-buffered
  BasicBlock *saved_block = builder_->GetInsertBlock();
  auto info = *layout_->get_double_buffer();
  ir::phi_node *phi = info.phi;
  BasicBlock *phi_block = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
  // new instructions go right after the PHI nodes of that block
  if(phi_block->empty())
    builder_->SetInsertPoint(phi_block);
  else
    builder_->SetInsertPoint(&*phi_block->getFirstNonPHI());
  // create pointers
  ptr_ = builder_->CreatePHI(ptr_ty, 2);
  pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
  pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
  offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
  next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
  builder_->SetInsertPoint(saved_block);
}
|
||||
|
||||
|
||||
// Creates the shared tile for `v`.
// With double-buffering, the base pointer and offset depend on the role of
// `v`: the phi uses ptr_ together with offset_, the latch value uses
// next_ptr_ and the first value uses pre_ptr_. Every other value uses ptr_
// without offset.
tile* machine_shared_layout::create(ir::value *v) {
  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  auto double_buffer = layout_->get_double_buffer();
  Value *offset = nullptr;
  Value *ptr = ptr_;
  if(double_buffer){
    // offset
    if(v == double_buffer->phi)
      offset = offset_;
    // base pointer
    if(v == double_buffer->latch)
      ptr = next_ptr_;
    else if(v == double_buffer->first)
      ptr = pre_ptr_;
  }
  // create tile
  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
|
||||
|
||||
// Only stores its arguments; the derived constructors are the ones that
// populate `axes`.
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                                                       analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                                                       analysis::data_layout *layout)
  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) { }
|
||||
|
||||
// Creates the distributed tile for `v`.
// Dimensions of extent > 1 reuse the distributed axis registered in axes_;
// the others get a single-index axis holding the constant 0. The dimension
// order follows the order of the matching axes in layout_.
tile *machine_distributed_layout::create(ir::value *v) {
  Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  size_t rank = shapes.size();
  std::vector<distributed_axis> axes(rank);
  std::vector<int> order(rank);
  // compute axes
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] <= 1){
      axes[d].contiguous = 1;
      axes[d].values = {builder_->getInt32(0)};
      continue;
    }
    unsigned ax = a_axes_->get(v, d);
    axes[d] = axes_.at(ax);
  }
  // compute order
  std::iota(order.begin(), order.end(), 0);
  auto by_layout_order = [&](int x, int y) {
    unsigned axx = a_axes_->get(v, x);
    unsigned axy = a_axes_->get(v, y);
    size_t posx = layout_->find_axis(axx);
    size_t posy = layout_->find_axis(axy);
    // axes that are not part of the layout compare as equivalent
    if(posx >= rank || posy >= rank)
      return false;
    return layout_->get_order(posx) < layout_->get_order(posy);
  };
  std::sort(order.begin(), order.end(), by_layout_order);

  return new distributed_tile(ty, shapes, order, axes, *builder_);
}
|
||||
|
||||
// Emits the index computations of an mma884 layout and registers one
// distributed axis per dimension in axes_.
// Also sets the a/b operand offsets (offset_a_i_, offset_a_k_, offset_b_j_,
// offset_b_k_) and the pack sizes / pack counts members.
// Throws std::runtime_error("unsupported") for shapes of rank > 3.
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
                                             target *tgt, analysis::axes *a_axes,
                                             std::map<unsigned, distributed_axis>& axes,
                                             analysis::mma884_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  // split the local id into (warp id, id within the warp)
  Value *warp_size = builder_->getInt32(32);
  Value *local_id = tgt_->get_local_id(mod_, *builder_, 0);
  Value *lane = builder_->CreateURem(local_id, warp_size);
  Value *warp = builder_->CreateUDiv(local_id, warp_size);

  const auto& shape = layout->get_shape();
  if(shape.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shape.size() >= 3;

  Value *_1 = builder_->getInt32(1);
  Value *_2 = builder_->getInt32(2);
  Value *_3 = builder_->getInt32(3);
  Value *_4 = builder_->getInt32(4);
  Value *_16 = builder_->getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout->fpw(0);
  unsigned fpw_1 = layout->fpw(1);
  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout->wpt(0);
  unsigned wpt_1 = layout->wpt(1);
  unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
  // mma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // mma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetition
  unsigned num_rep_0 = shape[0] / hmma_bts_0;
  unsigned num_rep_1 = shape[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *hi_bit_a = builder_->CreateAnd(lane, _16);
  Value *hi_quad_a = builder_->CreateUDiv(hi_bit_a, builder_->getInt32(4));
  Value *in_pair_off_a = builder_->CreateMul(hi_quad_a, builder_->getInt32(fpw_0 * pack_size_0_));
  Value *hi_bit_b = builder_->CreateAnd(lane, _16);
  Value *hi_quad_b = builder_->CreateUDiv(hi_bit_b, builder_->getInt32(4));
  Value *in_pair_off_b = builder_->CreateMul(hi_quad_b, builder_->getInt32(fpw_1 * pack_size_1_));

  // Quad pair id
  Value *low_lane_a = builder_->CreateURem(lane, _16);
  Value *pair_a_id = builder_->CreateUDiv(low_lane_a, _4);
  Value *low_lane_b = builder_->CreateURem(lane, _16);
  Value *pair_b_id = builder_->CreateUDiv(low_lane_b, _4);
  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder_->CreateURem(warp, builder_->getInt32(wpt_0));
  Value *warp_id_12 = builder_->CreateUDiv(warp, builder_->getInt32(wpt_0));
  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  Value *intra_a = builder_->CreateAdd(pair_a_off, in_pair_off_a);
  offset_a_i_ = builder_->CreateAdd(warp_offset_i, intra_a);
  offset_a_k_ = builder_->CreateAnd(lane, _3);
  // b offsets
  Value *intra_b = builder_->CreateAdd(pair_b_off, in_pair_off_b);
  offset_b_j_ = builder_->CreateAdd(warp_offset_j, intra_b);
  offset_b_k_ = builder_->CreateAnd(lane, _3);

  // c offsets
  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(lane, _1), offset_a_i_);
  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(lane, _2),
                                          builder_->CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    Value *shift = builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2);
    idx_i.push_back(builder_->CreateAdd(offset_c_i, shift));
  }
  // j indices (two consecutive indices per step)
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    Value *shift_lo = builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_);
    idx_j.push_back(builder_->CreateAdd(offset_c_j, shift_lo));
    Value *shift_hi = builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1);
    idx_j.push_back(builder_->CreateAdd(offset_c_j, shift_hi));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));

  /* axes */
  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
|
||||
|
||||
|
||||
// Emits the index computations of a scanline layout and registers one
// distributed axis per dimension in axes_.
// The linear thread id is split into one thread id per dimension (following
// the layout order, using mts as the extent); then, for each dimension, the
// list of indices handled by a thread is generated from nts and mts.
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
                                                 target *tgt,
                                                 analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
                                                 analysis::scanline_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value *local_id = tgt_->get_local_id(mod_, *builder_, 0);
  Value *lane = builder_->CreateURem(local_id, warp_size);
  Value *warp = builder_->CreateUDiv(local_id, warp_size);

  auto order = layout->get_order();
  const auto& shape = layout->get_shape();
  Value *warp_base = builder_->CreateMul(warp, builder_->getInt32(32));
  Value *linear_id = builder_->CreateAdd(warp_base, lane);
  // Delinearize
  size_t dim = shape.size();
  std::vector<Value*> thread_id(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *extent = builder_->getInt32(layout->mts(order[k]));
    Value *rem = builder_->CreateURem(linear_id, extent);
    linear_id = builder_->CreateUDiv(linear_id, extent);
    thread_id[order[k]] = rem;
  }
  thread_id[order[dim - 1]] = linear_id;
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    int nts = layout->nts(k);
    int mts = layout->mts(k);
    std::string prefix = "idx_" + std::to_string(k) + "_";
    Value *first_idx = builder_->CreateMul(thread_id[k], builder_->getInt32(nts));
    unsigned per_block = nts * mts;
    unsigned per_thread = nts * shape[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      // groups of nts consecutive indices, one group every per_block elements
      unsigned offset = n / nts * per_block + n % nts;
      idx_list[n] = builder_->CreateAdd(first_idx, builder_->getInt32(offset), prefix + std::to_string(n));
    }
    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
214
lib/codegen/selection/machine_value.cc
Normal file
214
lib/codegen/selection/machine_value.cc
Normal file
@@ -0,0 +1,214 @@
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
/* Distributed Tile */
|
||||
void distributed_tile::init_indices() {
|
||||
std::vector<size_t> id(axes_.size(), 0);
|
||||
// build
|
||||
size_t k = 0;
|
||||
while(true) {
|
||||
indices_t current;
|
||||
for(size_t d = 0; d < id.size(); d++)
|
||||
current.push_back(axes_[d].values[id[d]]);
|
||||
size_t sz = indices_.size();
|
||||
indices_[current] = sz;
|
||||
values_[current] = nullptr;
|
||||
ordered_indices_.push_back(current);
|
||||
id[order_[0]]++;
|
||||
while(id[order_[k]] == axes_[order_[k]].values.size()){
|
||||
if(k == id.size() - 1)
|
||||
return;
|
||||
id[order_[k++]] = 0;
|
||||
id[order_[k]]++;
|
||||
}
|
||||
k = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Stores the axes, order and builder, then enumerates all indices of the
// tile (init_indices).
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
  : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
  init_indices();
}
|
||||
|
||||
// Binds `x` to index `idx`. Asserts that `x` has the tile's element type and
// that no value was already bound to `idx`.
void distributed_tile::set_value(indices_t idx, Value *x) {
  assert(x->getType() == ty_ && "cannot set a value of different type");
  Value *&slot = values_[idx];
  assert(!slot && "value cannot be set twice");
  slot = x;
}
|
||||
|
||||
// Returns the value bound to `idx`. Uses at(), so an unknown index throws;
// asserts that the value was set.
Value* distributed_tile::get_value(indices_t idx) {
  Value *bound = values_.at(idx);
  assert(bound && "value has not been set");
  return bound;
}
|
||||
|
||||
// Returns the linear id that init_indices() assigned to `idx`.
// NOTE(review): the lookup goes through operator[], so an unknown `idx` is
// presumably inserted with a default id instead of being rejected -- confirm
// that callers only pass registered indices.
unsigned distributed_tile::get_linear_index(indices_t idx) {
  return indices_[idx];
}
|
||||
|
||||
// Returns the `id`-th index tuple in registration order (bounds-checked by at()).
indices_t distributed_tile::get_ordered_indices(unsigned id) {
  return ordered_indices_.at(id);
}
|
||||
|
||||
|
||||
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
|
||||
if(end < 0)
|
||||
end = ordered_indices_.size() + end + 1;
|
||||
for(unsigned i = start; i < end; i++)
|
||||
fn(ordered_indices_[i]);
|
||||
}
|
||||
|
||||
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
|
||||
int rank = sizes.size();
|
||||
int len = 1;
|
||||
for(int s: sizes)
|
||||
len *= s;
|
||||
|
||||
for(int i = 0; i < len; i++){
|
||||
indices_t idx(rank);
|
||||
int current = i;
|
||||
for(int k = 0; k < rank; k++){
|
||||
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
|
||||
current = current / sizes[k];
|
||||
}
|
||||
fn(idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Shared Tile */
|
||||
// Splits `arg` into a non-constant part and a constant part:
//  - a constant yields (0, arg);
//  - an add with exactly one constant operand yields (other operand, constant);
//  - an add of two constants yields (0, arg);
//  - anything else yields (arg, 0).
// The zero used for the missing part is a 32-bit integer constant.
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
  Constant *zero = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
  // the whole value is already a constant
  if(isa<Constant>(arg)){
    cst = arg;
    non_cst = zero;
    return;
  }
  // only additions are decomposed
  BinaryOperator *add = dyn_cast<BinaryOperator>(arg);
  const bool is_add = add && add->getOpcode() == llvm::BinaryOperator::Add;
  if(!is_add){
    non_cst = arg;
    cst = zero;
    return;
  }
  Value *lhs = add->getOperand(0);
  Value *rhs = add->getOperand(1);
  Constant *cst_lhs = dyn_cast<Constant>(lhs);
  Constant *cst_rhs = dyn_cast<Constant>(rhs);
  if(cst_lhs && cst_rhs){
    cst = arg;
    non_cst = zero;
    return;
  }
  if(cst_lhs){
    cst = cst_lhs;
    non_cst = rhs;
    return;
  }
  if(cst_rhs){
    cst = cst_rhs;
    non_cst = lhs;
    return;
  }
  // neither operand is constant
  non_cst = arg;
  cst = zero;
}
|
||||
|
||||
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
|
||||
non_cst_idx.clear();
|
||||
cst_idx.clear();
|
||||
for(Value *idx: arg_idx){
|
||||
Value *non_cst, *cst;
|
||||
extract_constant(idx, non_cst, cst);
|
||||
non_cst_idx.push_back(non_cst);
|
||||
cst_idx.push_back(cst);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Emits the linear offset of element `idx`:
//   sum over i of idx[perm[i]] * strides[i]
// where strides[order[0]] is 1 and each following axis (in `order`) has the
// previous stride multiplied by the previous axis' shape.
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
                                  const std::vector<int>& perm, const std::vector<int>& order,
                                  indices_t idx) {
  // strides, built from the fastest axis outwards
  std::vector<Value*> strides(order.size());
  strides[order[0]] = builder.getInt32(1);
  for(size_t i = 1; i < idx.size(); i++){
    Value *prev_stride = strides[order[i-1]];
    Value *prev_shape = builder.getInt32(shapes[order[i-1]]);
    strides[order[i]] = builder.CreateMul(prev_stride, prev_shape);
  }
  // accumulate the weighted index components
  Value *offset = builder.getInt32(0);
  for(size_t i = 0; i < strides.size(); i++){
    Value *term = builder.CreateMul(idx[perm[i]], strides[i]);
    offset = builder.CreateAdd(offset, term);
  }
  return offset;
}
|
||||
|
||||
// Wraps the buffer at `ptr`. The vector size starts at 1 and scalar return
// mode is selected; an empty `perm` is replaced by the identity permutation
// over the tile's dimensions.
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
  tile(ty, shapes),
  order_(order),
  ptr_(ptr),
  builder_(builder),
  offset_(offset),
  vector_size_(1),
  perm_(perm)
{
  return_vector_ = false;
  if(!perm_.empty())
    return;
  // default: identity permutation 0, 1, ..., rank-1
  perm_.resize(shapes.size());
  std::iota(perm_.begin(), perm_.end(), 0);
}
|
||||
|
||||
// Stores `value` at the element addressed by `idx`. The element pointer is
// bit-cast to a pointer to `value`'s type, keeping the address space.
void shared_tile::set_value(indices_t idx, Value *value) {
  Value *linear = shared_offset(builder_, shapes_, perm_, order_, idx);
  Value *slot = builder_.CreateGEP(ptr_, linear);
  unsigned addr_space = slot->getType()->getPointerAddressSpace();
  Type *slot_ty = value->getType()->getPointerTo(addr_space);
  slot = builder_.CreateBitCast(slot, slot_ty);
  builder_.CreateStore(value, slot);
}
|
||||
|
||||
// Sets the number of elements that get_value() treats as one vector.
void shared_tile::set_vector_size(unsigned vector_size) {
  vector_size_ = vector_size;
}
|
||||
|
||||
// Chooses whether get_value() returns the whole loaded vector (true) or
// extracts a single element from it (false).
void shared_tile::set_return_mode(bool return_vector){
  return_vector_ = return_vector;
}
|
||||
|
||||
|
||||
// Loads the element addressed by `idx`.
// The index is split into a non-constant and a constant part; the pointer for
// the non-constant part is cached in ptr_cache_ and the constant part is
// applied as an extra offset. When vector_size_ > 1 the load is a vector load
// and, unless vector return mode is on, one element is extracted from it.
Value* shared_tile::get_value(indices_t idx) {
  indices_t non_cst_idx, cst_idx;
  extract_constant(idx, non_cst_idx, cst_idx);
  // one cached base pointer per distinct non-constant index
  Value *&base_ptr = ptr_cache_[non_cst_idx];
  // element type and lane count of the vector view; pairs of halves are
  // viewed as one 32-bit integer
  unsigned num_lanes = vector_size_;
  Type *elt_ty = ty_;
  if(elt_ty->isHalfTy() && (num_lanes % 2 == 0)){
    elt_ty = IntegerType::get(elt_ty->getContext(), 32);
    num_lanes = num_lanes / 2;
  }
  const bool vectorized = vector_size_ > 1;
  if(base_ptr == nullptr){
    // (an experiment that hoisted this computation next to the first
    //  non-constant index used to live here, commented out)
    Value *base_offset = shared_offset(builder_, shapes_, perm_, order_, non_cst_idx);
    base_ptr = builder_.CreateGEP(ptr_, base_offset);
    if(vectorized){
      Type *vec_ty = VectorType::get(elt_ty, num_lanes);
      Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
      base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
    }
  }
  // constant part of the address, expressed in whole vectors when vectorized
  Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
  Value *div = offset;
  if(vectorized)
    div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
  Value *ptr = builder_.CreateGEP(base_ptr, div);
  Value *result = builder_.CreateLoad(ptr);
  // scalar mode: pick the requested lane out of the loaded vector
  if(return_vector_ == false && vectorized) {
    Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
    result = builder_.CreateExtractElement(result, rem);
  }
  return result;
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
174
lib/codegen/target.cc
Normal file
174
lib/codegen/target.cc
Normal file
@@ -0,0 +1,174 @@
|
||||
#include "triton/codegen/target.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include <iostream>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// base
|
||||
bool target::is_gpu() const {
|
||||
return is_gpu_;
|
||||
}
|
||||
|
||||
// AMD
|
||||
// Marks `fn` as a kernel by giving it the AMDGPU kernel calling convention.
// `builder`, `ctx` and `module` are not used by this target.
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
  fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}
|
||||
|
||||
// Emits a call to the amdgcn s_barrier intrinsic and returns the call.
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
  Function *s_barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
  return builder.CreateCall(s_barrier, {});
}
|
||||
|
||||
// Returns stride * block-id along axis `ax`.
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* block = get_block_id(module, builder, ax);
  Value* scale = builder.getInt32(stride);
  return builder.CreateMul(scale, block);
}
|
||||
|
||||
// Memory fences are not supported on this target: always throws.
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
  throw std::runtime_error("not implemented");
}
|
||||
|
||||
|
||||
// Emits a call to the amdgcn workgroup-id intrinsic for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> workgroup_ids = {
    Intrinsic::amdgcn_workgroup_id_x,
    Intrinsic::amdgcn_workgroup_id_y,
    Intrinsic::amdgcn_workgroup_id_z
  };
  Value* intrinsic = Intrinsic::getDeclaration(module, workgroup_ids[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// Emits a call to the r600 "read ngroups" intrinsic for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ngroups_ids = {
    Intrinsic::r600_read_ngroups_x,
    Intrinsic::r600_read_ngroups_y,
    Intrinsic::r600_read_ngroups_z
  };
  Value* intrinsic = Intrinsic::getDeclaration(module, ngroups_ids[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// Emits a call to the amdgcn workitem-id intrinsic for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> workitem_ids = {
    Intrinsic::amdgcn_workitem_id_x,
    Intrinsic::amdgcn_workitem_id_y,
    Intrinsic::amdgcn_workitem_id_z
  };
  Function *intrinsic = Intrinsic::getDeclaration(module, workitem_ids[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// NVIDIA
|
||||
|
||||
// Marks `fn` as a kernel by appending the tuple (fn, "kernel", i32 1) to the
// module's "nvvm.annotations" named metadata.
void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
  Metadata *annotation[] = {
    ValueAsMetadata::get(fn),
    MDString::get(ctx, "kernel"),
    ValueAsMetadata::get(builder.getInt32(1))
  };
  NamedMDNode *annotations = module->getOrInsertNamedMetadata("nvvm.annotations");
  annotations->addOperand(MDNode::get(ctx, annotation));
}
|
||||
|
||||
// Emits a call to the nvvm barrier0 intrinsic and returns the call.
Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
  Function *barrier0 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
  return builder.CreateCall(barrier0, {});
}
|
||||
|
||||
// Emits a call to the nvvm membar.gl intrinsic and returns the call.
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
  Function *membar_gl = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
  return builder.CreateCall(membar_gl, {});
}
|
||||
|
||||
|
||||
// Returns stride * block-id along axis `ax`.
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* block = get_block_id(module, builder, ax);
  Value* scale = builder.getInt32(stride);
  return builder.CreateMul(scale, block);
}
|
||||
|
||||
// Emits a read of the PTX ctaid special register for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ctaid_sregs = {
    Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_z
  };
  Value* intrinsic = Intrinsic::getDeclaration(module, ctaid_sregs[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// Emits a read of the PTX tid special register for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> tid_sregs = {
    Intrinsic::nvvm_read_ptx_sreg_tid_x,
    Intrinsic::nvvm_read_ptx_sreg_tid_y,
    Intrinsic::nvvm_read_ptx_sreg_tid_z
  };
  Function *intrinsic = Intrinsic::getDeclaration(module, tid_sregs[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// Emits a read of the PTX nctaid special register for axis `ax`
// (0 = x, 1 = y, 2 = z) and returns the call. `ax` is not range-checked.
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> nctaid_sregs = {
    Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_z
  };
  Value* intrinsic = Intrinsic::getDeclaration(module, nctaid_sregs[ax]);
  return builder.CreateCall(intrinsic, {});
}
|
||||
|
||||
// CPU
|
||||
|
||||
// Intentionally a no-op: on this target an ordinary function already serves
// as a kernel, so nothing has to be annotated.
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
}
|
||||
|
||||
// There is no barrier on this target; a placeholder `0 + 0` is built instead.
// NOTE(review): an add of two constants is presumably folded by the builder
// into a constant rather than an instruction, which would make the cast below
// invalid -- confirm how callers use the returned pointer.
Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
  Value *placeholder = builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
  return (Instruction*)placeholder;
}
|
||||
|
||||
// There is no memory fence on this target; a placeholder `0 + 0` is built instead.
// NOTE(review): an add of two constants is presumably folded by the builder
// into a constant rather than an instruction, which would make the cast below
// invalid -- confirm how callers use the returned pointer.
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
  Value *placeholder = builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
  return (Instruction*)placeholder;
}
|
||||
|
||||
|
||||
// Returns the block id along axis `ax` (0, 1 or 2), which on this target is
// one of the last three arguments of the function currently being built.
//
// The lookup is done afresh on every call. It used to be cached in a
// function-local `static` array, which is initialised only once: every later
// call -- including calls made while building a different function -- got the
// arguments of the first function ever seen.
//
// Throws std::runtime_error if `ax` is not in [0, 3) or if the enclosing
// function has fewer than three parameters.
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
  const Function *fn = builder.GetInsertBlock()->getParent();
  size_t num_params = fn->getFunctionType()->getNumParams();
  if(ax >= 3 || num_params < 3)
    throw std::runtime_error("invalid block-id axis or kernel signature");
  // trailing arguments are laid out as (..., id0, id1, id2)
  const Argument *block_id = fn->arg_begin() + (num_params - 3 + ax);
  return const_cast<Argument*>(block_id);
}
|
||||
|
||||
// Querying the number of blocks is not supported on this target: always throws.
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  throw std::runtime_error("not implemented");
}
|
||||
|
||||
|
||||
// Returns stride * block-id along axis `ax`.
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* block = get_block_id(module, builder, ax);
  Value* scale = builder.getInt32(stride);
  return builder.CreateMul(scale, block);
}
|
||||
|
||||
// The local id is the constant 0 on this target, whatever the axis.
Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  return builder.getInt32(0);
}
|
||||
|
||||
}
|
||||
}
|
143
lib/codegen/transform/coalesce.cc
Normal file
143
lib/codegen/transform/coalesce.cc
Normal file
@@ -0,0 +1,143 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
// Keeps pointers to the alignment and layout analyses used by run().
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
  : align_(align),
    layout_(layouts)
{ }
|
||||
|
||||
// Find all values that are used as pointer operands in LD/ST
|
||||
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Files `i` under the axis along which its pointer operand has the largest
// contiguity, as reported by the alignment analysis (first maximum wins).
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
  auto contiguous = align_->contiguous(i->get_pointer_operand());
  auto best = std::max_element(contiguous.begin(), contiguous.end());
  int axis = std::distance(contiguous.begin(), best);
  result[axis].push_back(i);
}
|
||||
|
||||
// Returns a fresh copy of the computation that produces `x`:
//  - values already handled (present in `seen`) return their recorded clone;
//  - non-instructions and copy_to_shared instructions are returned unchanged;
//  - a load is not cloned: a copy_to_shared of it is inserted right after it;
//  - any other instruction is cloned right after the original and its operands
//    are rematerialized recursively.
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
                                   std::map<ir::value*, ir::value*>& seen) {
  if(seen.find(x) != seen.end())
    return seen.at(x);
  auto inst = dynamic_cast<ir::instruction*>(x);
  // not an instruction -- forward value
  if(!inst)
    return x;
  // already in shared memory -- forward value
  if(dynamic_cast<ir::copy_to_shared_inst*>(x))
    return x;
  // new code goes immediately after the original instruction
  auto& siblings = inst->get_parent()->get_inst_list();
  auto after = std::find(siblings.begin(), siblings.end(), inst);
  ++after;
  builder.set_insert_point(after);
  if(dynamic_cast<ir::load_inst*>(x))
    return builder.insert(ir::copy_to_shared_inst::create(x));
  // default -- recursive clone
  ir::instruction *cloned = builder.insert(inst->clone());
  seen[inst] = cloned;
  for(ir::value *op: cloned->ops()){
    ir::value *fresh = rematerialize(op, builder, seen);
    cloned->replace_uses_of_with(op, fresh);
  }
  return cloned;
}
|
||||
|
||||
// Runs the pass on `mod` in three steps:
//  1. for every mma884 layout, follow the users of its dot instruction until a
//     floating-point truncation is reached and insert a recoalesce right after
//     that truncation;
//  2. for every layout, group the I/O instructions whose pointer operand
//     belongs to it by their most contiguous axis, and schedule every group
//     except the one with the highest axis for rematerialization;
//  3. rematerialize the operands of the scheduled instructions and, for loads,
//     route their result through a copy_to_shared.
void coalesce::run(ir::module &mod) {
  size_t num_groups = layout_->num_layouts();

  for(size_t id = 0; id < num_groups; id++) {
    if(!layout_->get(id)->to_mma884())
      continue;
    // locate the dot instruction of this layout (the last one wins)
    const auto& values = layout_->values_of(id);
    ir::value* dot = nullptr;
    for(ir::value *v: values)
      if(auto x = dynamic_cast<ir::dot_inst*>(v))
        dot = x;
    // a layout without any dot has nothing to follow; seeding the worklist
    // with a null pointer would be dereferenced below
    if(!dot)
      continue;

    ir::builder& builder = mod.get_builder();
    std::vector<ir::value*> worklist = {dot};
    std::set<ir::value*> seen;
    while(!worklist.empty()) {
      ir::value *current = worklist.back();
      seen.insert(current);
      worklist.pop_back();
      // stop if trunc
      if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
        builder.set_insert_point_after(x);
        ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
        builder.insert(rc);
        x->replace_all_uses_with(rc);
        // the line above also rewired rc's own operand; point it back at x
        rc->replace_uses_of_with(rc, x);
        break;
      }
      // recurse
      for(ir::user *u: current->get_users())
        if(seen.find(u) == seen.end())
          worklist.push_back(u);
    }
  }

  // find values to rematerialize
  std::vector<ir::io_inst*> remat;
  for(size_t id = 0; id < num_groups; id++) {
    const auto& values = layout_->values_of(id);
    // extract pointers used in ld/st operations
    std::set<ir::io_inst*> io;
    for(ir::value *v: values)
      extract_io_use(v, io);
    // extract leading axes
    std::map<int, std::vector<ir::io_inst*>> axes;
    for(ir::io_inst *i: io){
      if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
        extract_ld(i, axes);
    }
    // update list of values to rematerialize
    if(axes.empty())
      continue;
    // skip the group with the highest axis key, prepend all the others
    for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
      remat.insert(remat.begin(), it->second.begin(), it->second.end());
  }

  // rematerialize values
  for(ir::io_inst *r: remat) {
    ir::builder& builder = mod.get_builder();
    // rematerialize operands
    std::map<ir::value*, ir::value*> seen;
    for(ir::value *op: r->ops())
      r->replace_uses_of_with(op, rematerialize(op, builder, seen));
    // copy to shared if load
    auto& inst_list = r->get_parent()->get_inst_list();
    auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
    builder.set_insert_point(pos);
    if(dynamic_cast<ir::load_inst*>(r)){
      ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
      r->replace_all_uses_with(cts);
      // undo the self-reference created by the replacement above
      cts->replace_uses_of_with(cts, r);
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
95
lib/codegen/transform/cts.cc
Normal file
95
lib/codegen/transform/cts.cc
Normal file
@@ -0,0 +1,95 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Tells whether operand number `op` of `i` is flagged as a shared-memory
// operand: operands 0 and 1 of a dot, operand 0 of a copy-from-shared or a
// transposition. Everything else is not.
inline bool is_shmem_op(ir::instruction* i, int op) {
  switch(i->get_id()){
    case ir::INST_DOT:
      return op==0 || op==1;
    case ir::INST_COPY_FROM_SHARED:
    case ir::INST_TRANS:
      return op==0;
    default:
      return false;
  }
}
|
||||
|
||||
// Tells whether `v` is flagged as a shared-memory result, i.e. whether it is a
// transposition, a reduction or a copy-to-shared instruction.
// Values that are not instructions never are.
inline bool is_shmem_res(ir::value* v){
  ir::instruction* inst = dynamic_cast<ir::instruction*>(v);
  if(!inst)
    return false;
  switch(inst->get_id()){
    case ir::INST_TRANS:
    case ir::INST_REDUCE:
    case ir::INST_COPY_TO_SHARED:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
|
||||
// run pass on module
|
||||
// Makes `parent` consume a copy of `x` instead of `x` itself. The copy is a
// copy-to-shared when `to_shared` is set and a copy-from-shared otherwise.
//  - If `x` is not an instruction, the copy is created at `parent`.
//  - If `x` is a phi node, each of its incoming values is handled recursively
//    with the phi as parent.
//  - If a copy-to-shared is requested and `x` is already a shared-memory
//    result, nothing is done.
//  - Otherwise the copy is created right after `x`.
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
  // creates the copy at the current insert point and rewires `parent` to it
  auto emit_copy = [&]() {
    ir::value *copy;
    if(to_shared)
      copy = builder.create_copy_to_shared(x);
    else
      copy = builder.create_copy_from_shared(x);
    parent->replace_uses_of_with(x, copy);
  };

  auto *inst = dynamic_cast<ir::instruction*>(x);
  // not an instruction
  if(!inst) {
    builder.set_insert_point(parent);
    emit_copy();
    return;
  }
  // phi node
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    for(unsigned n = 0; n < phi->get_num_incoming(); ++n)
      add_copy(phi, phi->get_incoming_value(n), builder, to_shared);
    return;
  }
  // already in shared memory
  if(to_shared && is_shmem_res(inst))
    return;
  // copy
  builder.set_insert_point_after(inst);
  emit_copy();
}
|
||||
|
||||
// Visits every instruction of every function in `mod` and inserts the copies
// its operands require:
//  - operands flagged by is_shmem_op get a copy-to-shared;
//  - for non-phi instructions, the remaining operands that are shared-memory
//    results get a copy-from-shared.
void cts::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  for(ir::function* fn: mod.get_function_list()){
    for(ir::basic_block* block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      size_t num_op = inst->get_num_operands();
      // copy to shared operands
      for(size_t k = 0; k < num_op; k++){
        if(!is_shmem_op(inst, k))
          continue;
        add_copy(inst, inst->get_operand(k), builder, true);
      }
      // copy from shared operands -- phi nodes are left alone
      if(dynamic_cast<ir::phi_node*>(inst))
        continue;
      for(size_t k = 0; k < num_op; k++){
        if(is_shmem_op(inst, k))
          continue;
        if(!is_shmem_res(inst->get_operand(k)))
          continue;
        add_copy(inst, inst->get_operand(k), builder, false);
      }
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
75
lib/codegen/transform/dce.cc
Normal file
75
lib/codegen/transform/dce.cc
Normal file
@@ -0,0 +1,75 @@
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Mark-and-sweep dead code elimination over every function of `mod`.
// Roots are returns, branches, stores, atomics and barriers; everything a root
// transitively uses as an operand is kept, every other instruction is erased.
void dce::run(ir::module &mod) {
  std::list<ir::instruction*> pending;
  std::set<ir::instruction*> live;

  // initialize work-list with the root instructions
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: rpo)
    for(ir::instruction *inst: block->get_inst_list()){
      bool is_root = false;
      switch(inst->get_id()){
        case ir::INST_RETURN:
        case ir::INST_UNCOND_BRANCH:
        case ir::INST_COND_BRANCH:
        case ir::INST_UNMASKED_STORE:
        case ir::INST_MASKED_STORE:
        case ir::INST_ATOMIC_ADD:
        case ir::INST_ATOMIC_CAS:
        case ir::INST_ATOMIC_EXCH:
        case ir::INST_BARRIER:
          is_root = true;
          break;
        default:
          break;
      }
      if(!is_root)
        continue;
      pending.push_back(inst);
      live.insert(inst);
    }
  }

  // mark -- ignore branches
  while(!pending.empty()){
    ir::instruction* current = pending.back();
    pending.pop_back();
    // every instruction used as an operand becomes live as well
    for(ir::value* op: current->ops()) {
      auto *def = dynamic_cast<ir::instruction*>(op);
      if(def && live.insert(def).second)
        pending.push_back(def);
    }
    // TODO: mark last instruction of current's reverse-dominance frontier
  }

  // sweep -- collect unmarked instructions
  std::vector<ir::instruction*> doomed;
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: rpo)
    for(ir::instruction *inst: block->get_inst_list()){
      if(live.find(inst) == live.end())
        doomed.push_back(inst);
    }
  }

  // delete
  for(ir::instruction* inst: doomed)
    inst->erase_from_parent();
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
76
lib/codegen/transform/disassociate.cc
Normal file
76
lib/codegen/transform/disassociate.cc
Normal file
@@ -0,0 +1,76 @@
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
// Records `root` at level `depth` of `result`, then recurses into its operands
// that are ir::user objects at level depth + 1. The walk does not go past
// make_range and splat instructions, and `seen` guarantees that each value is
// visited only once.
void extract_retile_chain(ir::user *root,
                          std::map<int, std::set<ir::user*>>& result,
                          int depth,
                          std::set<ir::value*>& seen) {
  const bool first_visit = seen.insert(root).second;
  if(!first_visit)
    return;
  result[depth].insert(root);
  // leaves of the chain
  const bool is_leaf = dynamic_cast<ir::make_range*>(root) ||
                       dynamic_cast<ir::splat_inst*>(root);
  if(is_leaf)
    return;
  for(ir::value *op: root->ops()){
    if(ir::user *next = dynamic_cast<ir::user*>(op))
      extract_retile_chain(next, result, depth + 1, seen);
  }
}
|
||||
|
||||
// For every reshape whose first operand is an ir::user, collects the chain of
// values feeding it (see extract_retile_chain) and gives the reshape a private
// copy of that chain: level by level, each member is cloned right before the
// original and the consumers one level up are rewired to the clone.
void disassociate::run(ir::module &mod) {
  ir::builder &bld = mod.get_builder();

  // reshape -> (depth -> values found at that depth)
  std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
  ir::for_each_instruction(mod, [&](ir::instruction *i){
    if(!dynamic_cast<ir::reshape_inst*>(i))
      return;
    std::map<int, std::set<ir::user*>> chains;
    std::set<ir::value*> seen;
    if(!dynamic_cast<ir::user*>(i->get_operand(0)))
      return;
    extract_retile_chain(i, chains, 0, seen);
    if(chains.size())
      clone_info[i] = chains;
  });

  for(const auto& entry: clone_info){
    ir::user *reshape = entry.first;
    const auto& levels = entry.second;
    // original -> clone, for everything cloned so far in this chain
    std::map<ir::instruction*, ir::instruction*> clone_map;
    // level 0 is the reshape itself, so cloning starts at level 1
    for(int depth = 1; levels.find(depth) != levels.end(); depth += 1){
      // clone all users
      for(ir::user* u: levels.at(depth)){
        ir::instruction *y = (ir::instruction*)u;
        ir::instruction *cloned = y->clone();
        bld.set_insert_point(y);
        bld.insert(cloned);
        clone_map[y] = cloned;
        // replace operands of parents
        if(depth > 1){
          for(ir::user* ux: levels.at(depth - 1))
            clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
        }
        else
          reshape->replace_uses_of_with(y, cloned);
      }
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
168
lib/codegen/transform/membar.cc
Normal file
168
lib/codegen/transform/membar.cc
Normal file
@@ -0,0 +1,168 @@
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
bool membar::intersect(const interval_vec_t &X, interval_t x) {
|
||||
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
|
||||
bool left_intersect = y.first <= x.first && x.first < y.second;
|
||||
bool right_intersect = y.first <= x.second && x.second < y.second;
|
||||
return left_intersect || right_intersect;
|
||||
});
|
||||
}
|
||||
|
||||
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
|
||||
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
|
||||
return intersect(X, y);
|
||||
});
|
||||
}
|
||||
|
||||
// Appends to `res` the interval [offset, offset + size) allocated to `v`.
// Nothing is appended unless `v` is a tile-typed instruction whose layout is a
// shared layout that has an offset in the allocation.
void membar::add_reference(ir::value *v, interval_vec_t &res){
  auto *inst = dynamic_cast<ir::instruction*>(v);
  if(!inst || !inst->get_type()->is_tile_ty())
    return;
  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
  if(!layout)
    return;
  if(!alloc_->has_offset(layout))
    return;
  unsigned begin = alloc_->offset(layout);
  res.push_back(interval_t(begin, begin + layout->get_size()));
}
|
||||
|
||||
// Collects into `res` the intervals referenced by the operands of `i`.
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
  for(ir::value *operand: i->ops()){
    add_reference(operand, res);
  }
}
|
||||
|
||||
// Collects into `res` the interval referenced by the result of `i`.
// Phi nodes and transpositions are not counted as writes.
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  if(dynamic_cast<ir::phi_node*>(i))
    return;
  if(dynamic_cast<ir::trans_inst*>(i))
    return;
  add_reference(i, res);
}
|
||||
|
||||
// Inserts a barrier for `instr`.
//  - For a regular instruction the barrier is created at `instr`.
//  - For a phi node one barrier is created per distinct incoming value, at the
//    last instruction of the block that defines that value. Every incoming
//    value is asserted to be an instruction.
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
  auto *phi = dynamic_cast<ir::phi_node*>(instr);
  if(!phi) {
    builder.set_insert_point(instr);
    builder.create_barrier();
    return;
  }
  // incoming values that already received their barrier
  std::set<ir::value*> incoming;
  for(unsigned n = 0; n < phi->get_num_incoming(); n++){
    ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
    assert(inc_val);
    const bool is_new = incoming.insert(inc_val).second;
    if(!is_new)
      continue;
    ir::basic_block *block = inc_val->get_parent();
    builder.set_insert_point(block->get_inst_list().back());
    builder.create_barrier();
  }
}
|
||||
|
||||
// Concatenates all the interval lists of `intervals` into one list, in order.
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
  membar::interval_vec_t merged;
  for(const auto& group: intervals)
    std::copy(group.begin(), group.end(), std::back_inserter(merged));
  return merged;
}
|
||||
|
||||
// Propagates the written/read interval sets through `block`.
// Starting from the sets valid on entry, each instruction is checked for a
// read-after-write or write-after-read hazard against what has been
// accumulated so far. A hazardous instruction is added to `insert_loc` and
// resets both sets. Instructions listed in `safe_war` never count as
// hazardous. Returns the (written, read) sets valid on exit.
std::pair<membar::interval_vec_t,
          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
                                                   const interval_vec_t &written_to,
                                                   const interval_vec_t &read_from,
                                                   std::set<ir::instruction*>& insert_loc,
                                                   std::set<ir::value*>& safe_war) {
  ir::basic_block::inst_list_t instructions = block->get_inst_list();
  interval_vec_t cur_written = written_to;
  interval_vec_t cur_read = read_from;

  for(ir::instruction *inst: instructions){
    interval_vec_t read, written;
    get_read_intervals(inst, read);
    get_written_intervals(inst, written);
    bool raw_hazard = intersect(cur_written, read);
    bool war_hazard = intersect(cur_read, written);
    // double buffering
    const bool is_safe = safe_war.find(inst) != safe_war.end();
    if(is_safe){
      war_hazard = false;
      raw_hazard = false;
    }
    // record hazards: the barrier placed here clears everything seen so far
    if(raw_hazard || war_hazard) {
      insert_loc.insert(inst);
      cur_written.clear();
      cur_read.clear();
    }
    cur_written.insert(cur_written.end(), written.begin(), written.end());
    cur_read.insert(cur_read.end(), read.begin(), read.end());
  }
  return std::make_pair(cur_written, cur_read);
}
|
||||
|
||||
// Inserts the barriers required by `mod`.
// For each function the written/read interval sets are propagated through the
// blocks in reverse post-order, repeating until the number of barrier
// locations stops growing; a barrier is then inserted at every recorded
// location.
void membar::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // Values of double-buffered shared layouts, except the double-buffer phi
  // itself. These can be read from and written to without needing
  // synchronization.
  std::set<ir::value*> safe_war;
  for(const auto& entry: layouts_->get_all()){
    analysis::shared_layout* layout = entry.second->to_shared();
    if(!layout)
      continue;
    if(!layout->get_double_buffer())
      continue;
    for(ir::value *v: layout->get_values()){
      if(v != layout->get_double_buffer()->phi)
        safe_war.insert(v);
    }
  }

  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    // interval sets valid at the exit of each block
    std::map<ir::basic_block*, interval_vec_t> written_to;
    std::map<ir::basic_block*, interval_vec_t> read_from;
    std::set<ir::instruction*> insert_locs;
    size_t prev_count = 0;
    for(;;){
      // find barrier location
      for(ir::basic_block *block: rpo){
        // exit sets of all predecessors
        std::vector<interval_vec_t> pred_written_to;
        std::vector<interval_vec_t> pred_read_from;
        for(ir::basic_block* pred: block->get_predecessors()){
          pred_written_to.push_back(written_to[pred]);
          pred_read_from.push_back(read_from[pred]);
        }
        // apply transfer function
        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
        written_to[block] = result.first;
        read_from[block] = result.second;
      }
      // fixed point: no new barrier location was found in this sweep
      size_t count = insert_locs.size();
      const bool done = (prev_count == count);
      prev_count = count;
      if(done)
        break;
    }
    for(ir::instruction* loc: insert_locs)
      insert_barrier(loc, builder);
  }
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
191
lib/codegen/transform/peephole.cc
Normal file
191
lib/codegen/transform/peephole.cc
Normal file
@@ -0,0 +1,191 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Builds the transposed counterpart of `value` using permutation `perm`.
//  - phi node: every incoming value is rewritten recursively and a new phi
//    over the transposed values is created right at the original phi.
//  - any other instruction: a transposition of it is inserted immediately
//    after the instruction.
//  - anything else (not an instruction): nothing is created.
// Returns the new value, or nullptr when `value` cannot be rewritten. A phi
// is not rewritable when it has no incoming value or when one of its incoming
// values is not rewritable; transpositions already inserted for earlier
// incoming values are left in place (unused) in that case.
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
                                 const std::vector<int>& perm) {
  if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
    const unsigned num_incoming = phi->get_num_incoming();
    // the type of the new phi is taken from its first operand, so there must
    // be at least one
    if(num_incoming == 0)
      return nullptr;
    // transpose operands
    std::vector<ir::value*> incs;
    incs.reserve(num_incoming);
    for(unsigned n = 0; n < num_incoming; n++){
      ir::value *inc = rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm);
      // bail out instead of dereferencing a null operand below
      if(!inc)
        return nullptr;
      incs.push_back(inc);
    }
    // create phi for transposed values
    builder.set_insert_point(phi);
    ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
    for(unsigned n = 0; n < num_incoming; n++)
      result->add_incoming(incs[n], phi->get_incoming_block(n));
    return result;
  }
  else if(auto i = dynamic_cast<ir::instruction*>(value)){
    // insert the transposition right after the instruction it transposes
    ir::basic_block* block = i->get_parent();
    auto it = std::find(block->begin(), block->end(), i);
    it++;
    builder.set_insert_point(it);
    ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
    trans->set_operand(0, i);
    return trans;
  }
  return nullptr;
}
|
||||
|
||||
// trans(phi) -> phi(trans(), trans()...)
// Applies only when `value` is a transposition with at most one user and at
// most one operand, and that operand is a phi node. Returns true when the
// uses of the transposition were redirected to the rewritten phi.
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
  ir::trans_inst* trans = dynamic_cast<ir::trans_inst*>(value);
  if(trans == nullptr)
    return false;
  auto users = trans->get_users();
  auto ops = trans->ops();
  if(!(users.size() <= 1 && ops.size() <= 1))
    return false;
  // the transposed operand has to be a phi node
  ir::value* operand = *ops.begin();
  ir::phi_node* phi = dynamic_cast<ir::phi_node*>(operand);
  if(phi == nullptr)
    return false;
  ir::value* transposed_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
  if(transposed_phi == nullptr)
    return false;
  trans->replace_all_uses_with(transposed_phi);
  return true;
}
|
||||
|
||||
// dot(a, b, 0) + c -> dot(a, b, c)
// Folds a floating-point addition into the accumulator of a dot product.
// The rewrite fires when `value` is an FAdd with a dot_inst operand (the
// left-hand one is preferred when both are dots) whose third operand is a
// splat of the floating-point constant 0.0. A new dot taking the other
// addend as accumulator is inserted at the addition, and all uses of the
// addition are redirected to it.
// Returns true when the rewrite was applied, false otherwise.
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  auto add = dynamic_cast<ir::binary_operator*>(value);
  if(add && add->get_op() == ir::binary_op_t::FAdd) {
    ir::value *lhs = add->get_operand(0);
    ir::value *rhs = add->get_operand(1);
    ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
    ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
    if(!lhs_dot && !rhs_dot)
      return false;
    ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
    // the addend that is not the selected dot
    ir::value *other = (dot == lhs) ? rhs : lhs;
    // the current accumulator must be splat(0.0)
    ir::value *acc = dot->get_operand(2);
    ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
    ir::constant_fp *_0 = nullptr;
    if(splat)
      _0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
    if(!(_0 && _0->get_value() == 0.0))
      return false;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    builder.set_insert_point(add);
    ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
    add->replace_all_uses_with(new_dot);
    return true;
  }
  // not a floating-point addition: nothing to rewrite (the function used to
  // fall off its end here without returning a value)
  return false;
}
|
||||
|
||||
// Replaces a reduction along an axis of extent 1 by a reshape of its
// argument to the tile shapes of the reduction's result type.
// Returns true when the reduction's uses were redirected to the reshape.
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  ir::reduce_inst* red = dynamic_cast<ir::reduce_inst*>(value);
  if(red == nullptr)
    return false;
  ir::value *arg = red->get_operand(0);
  auto arg_shapes = arg->get_type()->get_tile_shapes();
  // only a reduced axis of size one can be rewritten
  if(arg_shapes[red->get_axis()] != 1)
    return false;
  builder.set_insert_point(red);
  ir::value* reshaped = builder.create_reshape(arg, red->get_type()->get_tile_shapes());
  red->replace_all_uses_with(reshaped);
  return true;
}
|
||||
|
||||
// splat(1) * x -> x   and   x * splat(1) -> x
// Removes an integer multiplication by a splatted constant one: all uses of
// the multiplication are redirected to the other operand. The constant is
// required to be exactly 1; a splat of any other integer constant leaves the
// multiplication untouched (previously the value of the constant was never
// inspected, so x * splat(k) was rewritten to x for every k).
// Returns true when the multiplication was replaced, false otherwise.
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
  auto binop = dynamic_cast<ir::binary_operator*>(value);
  if(!binop || binop->get_op() != ir::binary_op_t::Mul)
    return false;
  // true iff `v` is a splat whose operand is the integer constant 1
  auto is_splat_of_one = [](ir::value *v) {
    auto *splat = dynamic_cast<ir::splat_inst*>(v);
    if(!splat)
      return false;
    auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
    return cst && cst->get_value() == 1;
  };
  ir::value *lhs = binop->get_operand(0);
  ir::value *rhs = binop->get_operand(1);
  if(is_splat_of_one(lhs)){
    binop->replace_all_uses_with(rhs);
    return true;
  }
  if(is_splat_of_one(rhs)){
    binop->replace_all_uses_with(lhs);
    return true;
  }
  return false;
}
|
||||
|
||||
|
||||
// gep(gep(ptr, 0 - off), off) -> ptr
// `value` must be a GEP whose pointer operand is itself a GEP. When the first
// index of the inner GEP is the subtraction (constant 0) - off and `off` is
// the first index of the outer GEP, all uses of the outer GEP are redirected
// to the inner GEP's pointer operand.
// Returns true when the replacement was made.
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
  auto outer = dynamic_cast<ir::getelementptr_inst*>(value);
  if(outer == nullptr)
    return false;
  auto inner = dynamic_cast<ir::getelementptr_inst*>(outer->get_pointer_operand());
  if(inner == nullptr)
    return false;
  // first index of the inner GEP has to be a binary operation
  auto inner_idx = *inner->idx_begin();
  auto neg = dynamic_cast<ir::binary_operator*>(inner_idx);
  if(neg == nullptr)
    return false;
  // ... more precisely a subtraction from the integer constant zero ...
  auto *minuend = dynamic_cast<ir::constant_int*>(neg->get_operand(0));
  const bool subtracts = neg->get_op() == ir::binary_op_t::Sub;
  const bool from_zero = minuend && (minuend->get_value()==0);
  // ... of the very offset the outer GEP adds back
  const bool same_offset = neg->get_operand(1) == *outer->idx_begin();
  if(!(subtracts && from_zero && same_offset))
    return false;
  outer->replace_all_uses_with(inner->get_pointer_operand());
  return true;
}
|
||||
|
||||
|
||||
// Runs the peephole rewrites over every instruction of every function in
// `mod`. Two phases, each repeated until a full sweep rewrites nothing new:
//   1. rewrite_dot only;
//   2. rewrite_mult, rewrite_trans_phi, rewrite_unit_red and
//      rewrite_gep_ptr_min_off_plus_off, tried in that order, stopping at the
//      first one that succeeds for a given instruction.
// An instruction that has been rewritten is remembered and never visited
// again within the same phase.
void peephole::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // instructions already rewritten in the current phase; its growth tells
  // whether a sweep made any modification
  std::set<ir::value*> seen;
  size_t n_before;

  // rewrite dots first
  do{
    n_before = seen.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      if(seen.count(inst) != 0)
        continue;
      if(rewrite_dot(inst, builder))
        seen.insert(inst);
    }
  }while(seen.size() != n_before);

  // rewrite other ops
  seen.clear();
  do{
    n_before = seen.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      if(seen.count(inst) != 0)
        continue;
      // short-circuit: at most one rewrite is applied per instruction
      const bool was_modified = rewrite_mult(inst, builder)
                             || rewrite_trans_phi(inst, builder)
                             || rewrite_unit_red(inst, builder)
                             || rewrite_gep_ptr_min_off_plus_off(inst, builder);
      if(was_modified)
        seen.insert(inst);
    }
  }while(seen.size() != n_before);
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
264
lib/codegen/transform/reassociate.cc
Normal file
264
lib/codegen/transform/reassociate.cc
Normal file
@@ -0,0 +1,264 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Returns `x` viewed as an instruction when it is a binary operator whose
// opcode is Add, and nullptr otherwise.
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
  ir::binary_operator *as_binop = dynamic_cast<ir::binary_operator*>(x);
  if(as_binop == nullptr)
    return nullptr;
  if(as_binop->get_op() != ir::binary_op_t::Add)
    return nullptr;
  return static_cast<ir::instruction*>(x);
}
|
||||
|
||||
// True when `x` is a constant, possibly wrapped in any number of retiling
// instructions (each one is looked through via its first operand).
inline bool is_cst(ir::value *x) {
  ir::value *cur = x;
  while(true){
    if(dynamic_cast<ir::constant*>(cur))
      return true;
    auto *retile = dynamic_cast<ir::retile_inst*>(cur);
    if(!retile)
      return false;
    // peel one retiling layer and look at what it wraps
    cur = retile->get_operand(0);
  }
}
|
||||
|
||||
// Re-orders the additions that make up the index expression `old_value` so
// that a constant operand, when there is one, ends up as a direct operand of
// the outermost addition.
//   old_value : index expression to rewrite
//   builder   : used to create the replacement instructions; its insertion
//               point is modified
//   noncst    : out - reset to `old_value` on entry; when the returned value
//               is an addition with a constant operand, set to its other
//               operand
//   cst       : out - reset to nullptr on entry; when the returned value is
//               an addition with a constant operand, set to that operand
// Both out-parameters are also handed to the recursive calls, which overwrite
// them along the way.
// Returns the rewritten value (`old_value` itself when nothing changed). When
// a new value was built, every use of `old_value` is redirected to it.
ir::value *reassociate::reassociate_idx(ir::value *old_value,
                                        ir::builder &builder,
                                        ir::value *&noncst,
                                        ir::value *&cst){
  // value doesn't change by default
  ir::value* result = old_value;
  cst = nullptr;
  noncst = old_value;

  // handle retiling
  if(ir::instruction* retile = dynamic_cast<ir::retile_inst*>(old_value)){
    auto shapes = retile->get_type()->get_tile_shapes();
    ir::value *arg = retile->get_operand(0);
    ir::value *new_arg = reassociate_idx(arg, builder, noncst, cst);
    // retile(x + y) = retile(x) + retile(y)
    // only done when the argument is an addition carrying a constant
    ir::instruction* arg_add = is_bin_add(new_arg);
    if(arg_add && cst){
      ir::value *add_lhs = arg_add->get_operand(0);
      ir::value *add_rhs = arg_add->get_operand(1);
      ir::value *retiled_lhs = nullptr;
      ir::value *retiled_rhs = nullptr;
      // reshape
      if(dynamic_cast<ir::reshape_inst*>(retile)){
        builder.set_insert_point(retile);
        retiled_lhs = builder.create_reshape(add_lhs, shapes);
        retiled_rhs = builder.create_reshape(add_rhs, shapes);
        result = builder.create_add(retiled_lhs, retiled_rhs, retile->get_name());
      }
      // broadcast
      if(dynamic_cast<ir::broadcast_inst*>(retile)){
        builder.set_insert_point(retile);
        retiled_lhs = builder.create_broadcast(add_lhs, shapes);
        retiled_rhs = builder.create_broadcast(add_rhs, shapes);
        result = builder.create_add(retiled_lhs, retiled_rhs, retile->get_name());
      }
      // splat
      if(dynamic_cast<ir::splat_inst*>(retile)){
        builder.set_insert_point(retile);
        retiled_lhs = builder.create_splat(add_lhs, shapes);
        retiled_rhs = builder.create_splat(add_rhs, shapes);
        result = builder.create_add(retiled_lhs, retiled_rhs, retile->get_name());
      }
    }
  }

  // handle binary addition
  if(ir::instruction* add = is_bin_add(old_value)){
    builder.set_insert_point(add);
    std::string name = add->get_name();
    // operands are rewritten first; the second one is fetched only after the
    // first has been processed
    ir::value *lhs = reassociate_idx(add->get_operand (0), builder, noncst, cst);
    ir::value *rhs = reassociate_idx(add->get_operand(1), builder, noncst, cst);
    // the recursive calls may have moved the insertion point
    builder.set_insert_point(add);
    // (x + y) + z
    if(ir::instruction* lhs_add = is_bin_add(lhs)){
      ir::value *llhs = lhs_add->get_operand(0);
      ir::value *rlhs = lhs_add->get_operand(1);
      // (cst + x) + y -> cst + (x + y)
      if(is_cst(llhs)){
        ir::value *rest = builder.create_add(rlhs, rhs);
        result = builder.create_add(llhs, rest, name);
      }
      // (x + cst) + y -> cst + (x + y)
      if(is_cst(rlhs)){
        ir::value *rest = builder.create_add(llhs, rhs);
        result = builder.create_add(rlhs, rest, name);
      }
    }
    // x + (y + z)
    // NOTE(review): the two create_add calls below receive `cst` as a fourth
    // argument, unlike the two above - confirm against the builder interface
    // that this is intended.
    if(ir::instruction* rhs_add = is_bin_add(rhs)){
      ir::value *lrhs = rhs_add->get_operand(0);
      ir::value *rrhs = rhs_add->get_operand(1);
      // x + (cst + y) -> cst + (x + y)
      if(is_cst(lrhs)){
        ir::value *rest = builder.create_add(rrhs, lhs);
        result = builder.create_add(lrhs, rest, name, cst);
      }
      // x + (y + cst) -> cst + (x + y)
      if(is_cst(rrhs)){
        ir::value *rest = builder.create_add(lrhs, lhs);
        result = builder.create_add(rrhs, rest, name, cst);
      }
    }
  }

  // extract constant and non-constant
  if(ir::instruction *top_add = is_bin_add(result)){
    ir::value *top_lhs = top_add->get_operand(0);
    ir::value *top_rhs = top_add->get_operand(1);
    if(is_cst(top_lhs)){
      cst = top_lhs;
      noncst = top_rhs;
    }
    if(is_cst(top_rhs)){
      cst = top_rhs;
      noncst = top_lhs;
    }
  }

  // clean-up if some re-ordering happened
  if(old_value != result)
    old_value->replace_all_uses_with(result);
  return result;
}
|
||||
|
||||
/* run */
|
||||
// Runs the reassociation pass on `mod`.
//  1. Every make_range that appears as an instruction operand is replaced by
//     the sum of a make_range_dyn instruction and a make_range_sta value.
//  2. Pointers are then split, until a sweep replaces nothing new, into a
//     "dynamic" pointer followed by a GEP adding a constant offset. The split
//     is recorded in `infos` (value -> {dyn_ptr, sta_ptr}) and is propagated
//     through broadcasts, through GEPs on an already-split pointer, and
//     through two-operand phi nodes of the form py = phi(pa, pz) where pz is
//     a GEP on py.
void reassociate::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();

  // constant_range -> nv_dynamic_program_idx + nv_static_program_idx
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::make_range*> ranges;
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: rpo){
      // iterate through instruction
      for(ir::instruction *i: block->get_inst_list())
      for(ir::value* op: i->ops())
        if(auto *range = dynamic_cast<ir::make_range*>(op))
          ranges.push_back(range);
    }

    // the replacements are all created at the top of the first block
    builder.set_insert_point(rpo.front()->get_first_non_phi());
    for(ir::make_range* old_range: ranges){
      ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
      ir::value* static_range = ir::make_range_sta::get(old_range);
      ir::value* new_range = builder.create_add(dyn_range, static_range);
      old_range->replace_all_uses_with(new_range);
    }
  }

  // reassociate
  std::map<ir::value*, cst_info> infos;
  std::set<ir::value*> replaced;
  size_t n_replaced;
  do{
    n_replaced = replaced.size();
    for(ir::function *fn: mod.get_function_list()){
      std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
      // iterate through blocks
      for(ir::basic_block *block: rpo){
      // iterate through instruction
      for(ir::instruction *i: block->get_inst_list()){
        // retiling
        if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
          ir::value* op = rt->get_operand(0);
          if(infos.find(op) != infos.end()){
            builder.set_insert_point(rt);
            ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
            ir::value* dyn = infos.at(op).dyn_ptr;
            ir::value* cst = *sta->idx_begin();
            // only broadcasts of a split pointer are propagated: broadcast
            // the dynamic pointer and the constant offset separately
            if(dynamic_cast<ir::broadcast_inst*>(rt)) {
              auto shapes = rt->get_type()->get_tile_shapes();
              ir::value* ndyn = builder.create_broadcast(dyn, shapes);
              ir::value* broadcast = builder.create_broadcast(cst, shapes);
              ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
              infos[rt] = cst_info{ndyn, nsta};
            }
          }
        }
        // getelementptr instruction
        if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
          if(replaced.find(pz) != replaced.end())
            continue;
          // unpack GEP instruction
          ir::value* py = pz->get_pointer_operand();
          ir::value* offset = *pz->idx_begin();
          // reassociate index
          ir::value *sta = nullptr;
          ir::value *dyn = offset;
          reassociate_idx(offset, builder, dyn, sta);
          if(sta){
            // pz = gep(gep(py, dyn), sta)
            builder.set_insert_point(pz);
            ir::value *dyn_ptr = builder.create_gep(py, {dyn});
            ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
            pz->replace_all_uses_with(sta_ptr);
            infos[sta_ptr].dyn_ptr = dyn_ptr;
            infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
            replaced.insert(pz);
          }
          // reassociate pointer argument
          if(infos.find(py) != infos.end()){
            // py is already split: apply pz's offset to its dynamic part,
            // then re-add py's constant offset on top
            builder.set_insert_point(pz);
            ir::getelementptr_inst *sta = infos[py].sta_ptr;
            ir::value *dyn = infos[py].dyn_ptr;
            ir::value *cst = *sta->idx_begin();
            ir::value *off = *pz->idx_begin();
            ir::value *pz_dyn = builder.create_gep(dyn, {off});
            ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
            pz->replace_all_uses_with(pz_sta);
            infos[pz_sta].dyn_ptr = pz_dyn;
            infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta;
            replaced.insert(pz);
          }
          // reassociate phi-node pointer
          if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
            // only optimize the case where py = phi pa, pz for now
            std::vector<ir::value*> ops = phi->ops();
            if(ops.size() != 2)
              continue;
            if(ops[0] != pz && ops[1] != pz)
              continue;
            // grab incoming
            size_t idx_z = (ops[0] == pz) ? 0 : 1;
            size_t idx_a = (ops[0] == pz) ? 1 : 0;
            // check if pa is known to have constant offset
            ir::value *vpa = phi->get_incoming_value(idx_a);
            auto it_a = infos.find(vpa);
            if(it_a == infos.end())
              continue;
            // unpack dynamically/statically offset pointer
            ir::value *pa_dyn = it_a->second.dyn_ptr;
            ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
            // incoming value on the pz side (named distinctly so that it no
            // longer shadows the GEP `pz` being visited)
            ir::value *vpz = phi->get_incoming_value(idx_z);
            // extract offset
            ir::value *off = *pa_sta->idx_begin();
            builder.set_insert_point(phi);
            ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
            phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
            builder.set_insert_point(phi->get_parent()->get_first_non_phi());
            // re-add the offset
            ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
            phi->replace_all_uses_with(phi_sta);
            // remove offset from pz
            if(auto *x = dynamic_cast<ir::instruction*>(vpz)){
              // continue inserting right after the instruction producing the
              // value; the list is looked up in place rather than copied
              const auto &insts = x->get_parent()->get_inst_list();
              auto it = std::find(insts.begin(), insts.end(), x);
              it++;
              builder.set_insert_point(*it);
            }
            ir::value *_0 = builder.get_int32(0);
            if(off->get_type()->is_tile_ty())
              _0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
            ir::value *neg_off = builder.create_sub(_0, off);
            ir::value *pz_dyn = builder.create_gep(vpz, {neg_off});
            phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
            infos[phi_sta].dyn_ptr = phi_dyn;
            infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta;
            replaced.insert(phi);
          }
        }
      }
      }
    }
  }while(replaced.size() != n_replaced);
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user