[PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22)

--- a/lib/codegen/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
file(GLOB_RECURSE CODEGEN_SRC *.cc)

add_library(TritonCodeGen
  ${CODEGEN_SRC}
)

--- a/lib/codegen/analysis/align.cc
+++ /dev/null
@@ -1,533 +0,0 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>

namespace triton {
namespace codegen{
namespace analysis{


// Function for extended Euclidean Algorithm
int gcd_impl(int a, int b, int *x, int *y) {
  // Base Case
  if (a == 0) {
    *x = 0;
    *y = 1;
    return b;
  }
  int x1, y1; // To store results of recursive call
  int gcd = gcd_impl(b % a, a, &x1, &y1);
  // Update x and y using results of recursive call
  *x = y1 - (b / a) * x1;
  *y = x1;
  return gcd;
}

int gcd(int a, int b) {
  int x, y;
  return gcd_impl(a, b, &x, &y);
}

inline int lcm(int a, int b) {
  return (a * b) / gcd(a, b);
}
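
The helpers above are the textbook extended Euclidean algorithm: gcd_impl returns gcd(a, b) and, through x and y, Bezout coefficients with a*x + b*y == gcd(a, b), although the analyses below only ever consume the gcd itself. A minimal standalone check (hypothetical driver, assuming the three helpers above are in scope):

  #include <cassert>

  int main() {
    int x, y;
    int g = gcd_impl(12, 18, &x, &y);        // g == 6
    assert(g == 6 && 12 * x + 18 * y == g);  // Bezout identity holds
    assert(gcd(128, 48) == 16);              // alignments compose via gcd
    assert(lcm(4, 6) == 12);
    return 0;
  }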

template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
  return map[i] = value;
}

/*
 * is constant
 */

std::vector<unsigned> align::get_shapes(ir::value *v) {
  ir::type *ty = v->get_type();
  if(ty->is_block_ty())
    return ty->get_block_shapes();
  else
    return {1};
}

std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
  auto shapes = get_shapes(x);
  std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto it = is_constant_.find(inc);
    if(it != is_constant_.end())
      result = it->second;
  }
  // cache a tentative result before recursing so that cyclic phis terminate,
  // mirroring populate_max_contiguous_phi and populate_starting_multiple_phi
  add_to_cache(x, result, is_constant_);
  // recurse
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto cst = populate_is_constant(inc);
    for(size_t d = 0; d < cst.size(); d++)
      result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
  auto shapes = get_shapes(x);
  ir::value* op = x->get_operand(0);
  std::vector<cst_info> result;
  auto op_cst = populate_is_constant(op);
  for(auto d: shapes)
    result.push_back(cst_info{d, op_cst[0].value});
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_cst = populate_is_constant(op);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < x_shapes.size(); d++){
    cst_info ax;
    if(x_shapes[d] == 1)
      ax = {1, op_cst[current].value};
    else if(!is_skewed && x_shapes[d] == op_shapes[current])
      ax = {x_shapes[d], op_cst[current++].value};
    else {
      is_skewed = true;
      ax = {x_shapes[d], 0};
    }
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_cst = populate_is_constant(op);
  for(size_t d = 0; d < x_shapes.size(); d++)
    if(op_shapes[d] == 1)
      result.push_back(cst_info{x_shapes[d], op_cst[d].value});
    else
      result.push_back(op_cst[d]);
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value* lhs_op = x->get_operand(0);
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  auto max_contiguous = populate_max_contiguous(lhs_op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    cst_info ax;
    if(lhs[d].num_cst == 0 && rhs[d].value && x->is_int_div()){
      // todo might not be entirely true
      unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
      ax = {num_constants, 0};
    }
    else
      ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
  auto x_shapes = get_shapes(x);
  ir::value* lhs_op = x->get_operand(0);
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  std::vector<cst_info> result;
  for(size_t d = 0; d < x_shapes.size(); d++)
    result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
  return add_to_cache(x, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
  auto shapes = get_shapes(v);
  std::vector<cst_info> result(shapes.size(), {1, 0});
  return add_to_cache(v, result, is_constant_);
}

std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
  if(is_constant_.find(v) != is_constant_.end())
    return is_constant_.at(v);
  if(auto *x = dynamic_cast<ir::constant_int*>(v))
    return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
  if(auto *x = dynamic_cast<ir::phi_node*>(v))
    return populate_is_constant_phi(x);
  if(auto *x = dynamic_cast<ir::splat_inst*>(v))
    return populate_is_constant_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_is_constant_reshape(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_is_constant_broadcast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
    return populate_is_constant_binop(x);
  if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_is_constant_gep(x);
  return populate_is_constant_default(v);
}

/*
 * max contiguous
 */

std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
  auto shapes = get_shapes(x);
  std::vector<unsigned> result(shapes.size(), 1);
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto it = max_contiguous_.find(inc);
    if(it != max_contiguous_.end())
      result = it->second;
  }
  add_to_cache(x, result, max_contiguous_);
  // recurse
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto contiguous = populate_max_contiguous(inc);
    for(size_t d = 0; d < result.size(); d++)
      result[d] = std::min(result[d], contiguous[d]);
  }
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<unsigned> result;
  for(size_t d = 0; d < x_shapes.size(); d++)
    result.push_back(1);
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
  auto shapes = get_shapes(x);
  std::vector<unsigned> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_mc = populate_max_contiguous(op);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] == 1)
      result.push_back(1);
    else if(!is_skewed && shapes[d] == op_shapes[current])
      result.push_back(op_mc[current++]);
    else {
      is_skewed = true;
      result.push_back(1);
    }
  }
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
  auto shapes = get_shapes(x);
  std::vector<unsigned> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_block_shapes();
  auto op_mc = populate_max_contiguous(op);
  for(size_t d = 0; d < shapes.size(); d++)
    if(op_shapes[d] == 1)
      result.push_back(1);
    else
      result.push_back(op_mc[d]);
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
  auto shapes = get_shapes(x);
  ir::value* lhs = x->get_operand(0);
  ir::value* rhs = x->get_operand(1);
  auto lhs_max_contiguous = populate_max_contiguous(lhs);
  auto rhs_max_contiguous = populate_max_contiguous(rhs);
  auto lhs_cst_info = populate_is_constant(lhs);
  auto rhs_cst_info = populate_is_constant(rhs);
  auto lhs_starting_multiple = populate_starting_multiple(lhs);
  auto rhs_starting_multiple = populate_starting_multiple(rhs);
  std::vector<unsigned> result;
  for(size_t d = 0; d < shapes.size(); d++){
    unsigned value = 1;
    if(x->is_int_rem() && rhs_starting_multiple[d] > 0){
      value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]);
    }
    if(x->is_int_mult()){
      unsigned lvalue = 1, rvalue = 1;
      if(rhs_cst_info[d].value == 1)
        lvalue = lhs_max_contiguous[d];
      if(lhs_cst_info[d].value == 1)
        rvalue = rhs_max_contiguous[d];
      value = std::max(lvalue, rvalue);
    }
    if(x->is_int_add_sub()){
      unsigned lvalue = 1, rvalue = 1;
      lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
      rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
      value = std::max(lvalue, rvalue);
    }
    result.push_back(value);
  }
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
  auto shapes = get_shapes(x);
  ir::value* lhs = x->get_operand(0);
  ir::value* rhs = x->get_operand(1);
  auto lhs_max_contiguous = populate_max_contiguous(lhs);
  auto rhs_max_contiguous = populate_max_contiguous(rhs);
  auto lhs_cst_info = populate_is_constant(lhs);
  auto rhs_cst_info = populate_is_constant(rhs);
  std::vector<unsigned> result(shapes.size(), 1);
  for(size_t d = 0; d < shapes.size(); d++){
    unsigned lvalue = 1, rvalue = 1;
    if(lhs_cst_info[d].num_cst)
      lvalue = rhs_max_contiguous[d];
    if(rhs_cst_info[d].num_cst)
      rvalue = lhs_max_contiguous[d];
    result[d] = std::max(lvalue, rvalue);
  }
  return add_to_cache(x, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
  if(!v->get_type()->is_block_ty())
    return add_to_cache(v, {1}, max_contiguous_);
  auto shapes = v->get_type()->get_block_shapes();
  if(dynamic_cast<ir::make_range*>(v))
    return add_to_cache(v, {shapes[0]}, max_contiguous_);
  return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
  auto result = populate_max_contiguous(v->get_operand(0));
  return add_to_cache(v, result, max_contiguous_);
}

std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
  if(max_contiguous_.find(v) != max_contiguous_.end())
    return max_contiguous_.at(v);
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
    if(max_contiguous > 0)
      return add_to_cache(x, {max_contiguous}, max_contiguous_);
  }
  if(auto *x = dynamic_cast<ir::cast_inst*>(v))
    return populate_max_contiguous_cast(x);
  if(auto *x = dynamic_cast<ir::splat_inst*>(v))
    return populate_max_contiguous_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_max_contiguous_reshape(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_max_contiguous_broadcast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
    return populate_max_contiguous_binop(x);
  if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_max_contiguous_gep(x);
  if(auto *x = dynamic_cast<ir::phi_node*>(v))
    return populate_max_contiguous_phi(x);
  return populate_max_contiguous_default(v);
}


/*
 * starting multiple
 */

std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
  auto shapes = get_shapes(x);
  auto op = populate_starting_multiple(x->get_operand(0));
  std::vector<unsigned> result(shapes.size(), op[0]);
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
  auto op = populate_starting_multiple(x->get_operand(0));
  auto op_shapes = get_shapes(x->get_operand(0));
  auto shapes = get_shapes(x);
  std::vector<unsigned> result(shapes.size(), 1);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] == 1)
      result[d] = 1;
    else if(!is_skewed && shapes[d] == op_shapes[current])
      result[d] = op[current++];
    else {
      is_skewed = true;
      result[d] = 1;
    }
  }
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
  auto result = populate_starting_multiple(x->get_operand(0));
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(lhs.size(), 1);
  for(size_t d = 0; d < lhs.size(); d++){
    if(x->is_int_mult())
      result[d] = lhs[d] * rhs[d];
    if(x->is_int_add_sub())
      result[d] = gcd(lhs[d], rhs[d]);
    if(x->is_int_div())
      result[d] = 1;
    if(x->is_int_rem() && rhs[d] > 1){
      result[d] = gcd(lhs[d], rhs[d]);
    }
    if(x->is_shl())
      result[d] = lhs[d] << rhs[d];
    if(x->is_shr())
      result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
  }
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(lhs.size(), 1);
  for(size_t d = 0; d < lhs.size(); d++){
    result[d] = gcd(lhs[d], rhs[d]);
    // std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
  }
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
  auto shape = get_shapes(x);
  std::vector<unsigned> result(shape.size(), 1);
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    if(starting_multiple_.find(inc) != starting_multiple_.end())
      result = starting_multiple_.at(inc);
  }
  add_to_cache(x, result, starting_multiple_);
  // recurse
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto sm = populate_starting_multiple(inc);
    for(size_t d = 0; d < result.size(); d++)
      result[d] = gcd(result[d], sm[d]);
  }
  return add_to_cache(x, result, starting_multiple_);
}


std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
  auto result = populate_starting_multiple(x->get_operand(0));
  return add_to_cache(x, result, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
  ir::type* ty = v->get_type();
  if(ty->is_block_ty()) {
    return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::argument*>(v)){
    std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
    for(auto attr: attributes){
      if(attr.get_kind() == ir::multiple_of){
        return add_to_cache(x, {attr.get_value()}, starting_multiple_);
      }
      if(attr.get_kind() == ir::aligned){
        ir::type* ty = x->get_type()->get_pointer_element_ty();
        int nbits  = ty->get_primitive_size_in_bits();
        int nbytes = std::max<int>(nbits / 8, 1);
        return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
      }
    }
  }
  return add_to_cache(v, {1}, starting_multiple_);
}

std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
  if(starting_multiple_.find(v) != starting_multiple_.end())
    return starting_multiple_.at(v);
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(multiple_of > 0)
      return add_to_cache(x, {multiple_of}, starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::cast_inst*>(v))
    return populate_starting_multiple_cast(x);
  if(auto *x = dynamic_cast<ir::binary_operator*>(v))
    return populate_starting_multiple_binop(x);
  if(auto *x = dynamic_cast<ir::constant_int*>(v))
    return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
  if(auto *x = dynamic_cast<ir::make_range*>(v))
    return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
  if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_starting_multiple_gep(x);
  if(auto *x = dynamic_cast<ir::splat_inst*>(v))
    return populate_starting_multiple_splat(x);
  if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
    return populate_starting_multiple_reshape(x);
  if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_starting_multiple_broadcast(x);
  if(auto *x = dynamic_cast<ir::phi_node*>(v))
    return populate_starting_multiple_phi(x);
  return populate_starting_multiple_default(v);
}


unsigned align::get(ir::value *v, unsigned ax) const {
  unsigned starting_multiple = starting_multiple_.at(v)[ax];
  unsigned max_contiguous = max_contiguous_.at(v)[ax];
  return std::min(starting_multiple, max_contiguous);
}

std::vector<unsigned> align::contiguous(ir::value* v) const {
  return max_contiguous_.at(v);
}


void align::populate(ir::value *v) {
  populate_is_constant(v);
  populate_starting_multiple(v);
  populate_max_contiguous(v);
}

void align::run(ir::module &mod) {
  ir::for_each_value(mod, [this](ir::value* v) { populate(v); });
  // ir::for_each_value(mod, [this](ir::value* v) {
  //   if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
  //     std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
  // });
}


}
}
}
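
To see how the three properties combine, consider one axis of a pointer expression like splat(ptr) + make_range(0, 128), where ptr carries a multiple_of(16) attribute: starting_multiple propagates 16 through the gep, max_contiguous propagates 128 from the range, and align::get reports min(16, 128) = 16 safely-vectorizable elements. A sketch of that arithmetic with plain integers instead of the IR classes (illustrative values; assumes gcd from above):

  #include <algorithm>
  #include <cassert>

  int main() {
    unsigned ptr_multiple   = 16;   // from the multiple_of(16) attribute
    unsigned idx_multiple   = 0;    // make_range(0, 128) starts at 0
    unsigned idx_contiguous = 128;  // make_range is fully contiguous
    // populate_starting_multiple_gep: gcd of the operands' multiples
    unsigned gep_multiple = gcd(ptr_multiple, idx_multiple);   // 16
    // populate_max_contiguous_gep: the pointer is constant along the
    // axis, so contiguity comes from the index operand
    unsigned gep_contiguous = idx_contiguous;                  // 128
    // align::get: take the weakest of the two guarantees
    assert(std::min(gep_multiple, gep_contiguous) == 16);
    return 0;
  }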

--- a/lib/codegen/analysis/allocation.cc
+++ /dev/null
@@ -1,101 +0,0 @@
#include <algorithm>
#include <climits>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/ir/utils.h"

namespace triton{
namespace codegen{
namespace analysis{


void allocation::run(ir::module &mod) {
  using std::max;
  using std::min;
  typedef std::multimap<unsigned, segment> triples_map_type;

  std::vector<shared_layout*> I;
  for(auto x: liveness_->get())
    I.push_back(x.first);
  std::vector<shared_layout*> J = I;

  triples_map_type H;
  H.insert({0, segment{0, INT_MAX}});

  std::vector<shared_layout*> V;
  std::map<shared_layout*, unsigned> starts;
  while(!J.empty()){
    auto h_it = H.begin();
    unsigned w = h_it->first;
    segment xh = h_it->second;
    H.erase(h_it);
    auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){
      segment xj = liveness_->get(JJ);
      bool res = xj.intersect(xh);
      for(auto val: H)
        res = res && !val.second.intersect(xj);
      return res;
    });
    if(j_it != J.end()){
      unsigned size = (*j_it)->get_size();
      segment xj = liveness_->get(*j_it);
      starts[*j_it] = w;
      H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
      if(xh.start < xj.start)
        H.insert({w, segment{xh.start, xj.end}});
      if(xj.end < xh.end)
        H.insert({w, segment{xj.start, xh.end}});
      V.push_back(*j_it);
      J.erase(j_it);
    }
  }
  // Build interference graph
  std::map<shared_layout*, std::set<shared_layout*>> interferences;
  for(shared_layout* x: V)
    for(shared_layout* y: V){
      if(x == y)
        continue;
      unsigned X0 = starts[x], Y0 = starts[y];
      unsigned NX = x->get_size();
      unsigned NY = y->get_size();
      segment XS = {X0, X0 + NX};
      segment YS = {Y0, Y0 + NY};
      if(liveness_->get(x).intersect(liveness_->get(y))
         && XS.intersect(YS))
        interferences[x].insert(y);
    }
  // Initialize colors
  std::map<shared_layout*, int> colors;
  for(shared_layout* X: V)
    colors[X] = (X == V[0]) ? 0 : -1;
  // First-fit graph coloring
  std::vector<bool> available(V.size());
  for(shared_layout* x: V){
    // Non-neighboring colors are available
    std::fill(available.begin(), available.end(), true);
    for(shared_layout* Y: interferences[x]){
      int color = colors[Y];
      if(color >= 0)
        available[color] = false;
    }
    // Assigns first available color
    auto It = std::find(available.begin(), available.end(), true);
    colors[x] = std::distance(available.begin(), It);
  }
  // Finalize allocation
  for(shared_layout* x: V){
    unsigned Adj = 0;
    for(shared_layout* y: interferences[x])
      Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
    offsets_[x] = starts[x] + colors[x] * Adj;
  }
  // Save maximum size of induced memory space
  allocated_size_ = 0;
  for(shared_layout* x: V)
    allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
}

}
}
}
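
The pass is a linear-scan packing (the while loop over J) followed by first-fit graph coloring: two buffers interfere when both their live ranges and their tentative address ranges overlap, and interfering buffers are then pushed apart. A toy replay of the coloring step on three buffers, where A interferes with B and C but B and C are disjoint (self-contained sketch, not the original data structures):

  #include <algorithm>
  #include <cassert>
  #include <map>
  #include <set>
  #include <vector>

  int main() {
    std::map<char, std::set<char>> interf = {
        {'A', {'B', 'C'}}, {'B', {'A'}}, {'C', {'A'}}};
    std::vector<char> order = {'A', 'B', 'C'};
    std::map<char, int> color;
    for (char v : order) {
      std::vector<bool> avail(order.size(), true);
      for (char n : interf[v])
        if (color.count(n))
          avail[color[n]] = false;  // neighboring colors are taken
      color[v] = std::find(avail.begin(), avail.end(), true) - avail.begin();
    }
    // B and C end up with the same color: they may share an offset
    assert(color['A'] == 0 && color['B'] == 1 && color['C'] == 1);
    return 0;
  }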

--- a/lib/codegen/analysis/axes.cc
+++ /dev/null
@@ -1,162 +0,0 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>


namespace triton{
namespace codegen{
namespace analysis{

axes::axes() {}

void axes::update_graph_reduce(ir::instruction *i) {
  auto* red = static_cast<ir::reduce_inst*>(i);
  unsigned axis = red->get_axis();
  ir::value *arg = red->get_operand(0);
  auto in_shapes = arg->get_type()->get_block_shapes();
  unsigned current = 0;
  for(unsigned d = 0; d < in_shapes.size(); d++){
    if(d == axis)
      continue;
    graph_.add_edge({i, current++}, {arg, d});
  }
}

void axes::update_graph_reshape(ir::instruction *i) {
  auto* reshape = static_cast<ir::reshape_inst*>(i);
  // operands
  ir::value *op = reshape->get_operand(0);
  // shapes
  auto op_shapes = op->get_type()->get_block_shapes();
  auto res_shapes = reshape->get_type()->get_block_shapes();
  // construct edges
  unsigned current = 0;
  bool is_skewed = false;
  for(unsigned d = 0; d < res_shapes.size(); d++){
    bool same_shape = res_shapes[d] == op_shapes[current];
    // either add edge between axis or just add a node in the graph
    if(!is_skewed && same_shape)
      graph_.add_edge({i, d}, {op, current++});
    else
      graph_.add_edge({i, d}, {i, d});
    // reshaping is skewed
    if(res_shapes[d] > 1 && !same_shape)
      is_skewed = true;
  }
}

void axes::update_graph_trans(ir::instruction *i) {
  auto *trans = static_cast<ir::trans_inst*>(i);
  ir::value *op = trans->get_operand(0);
  auto perm = trans->get_perm();
  // add edge between axis perm[d] and axis d
  for(unsigned d = 0; d < perm.size(); d++)
    graph_.add_edge({i, perm[d]}, {op, d});
}

void axes::update_graph_broadcast(ir::instruction *i) {
  auto *broadcast = static_cast<ir::broadcast_inst*>(i);
  auto shapes = broadcast->get_type()->get_block_shapes();
  ir::value *op = broadcast->get_operand(0);
  ir::type *op_ty = op->get_type();
  const auto& op_shapes = op_ty->get_block_shapes();
  // add edge between non-broadcast axes
  for(unsigned d = 0; d < shapes.size(); d++)
    if(op_shapes[d] == shapes[d])
      graph_.add_edge({i, d}, {op, d});
}

void axes::update_graph_dot(ir::instruction *i) {
  auto *dot = static_cast<ir::dot_inst*>(i);
  auto shapes = dot->get_type()->get_block_shapes();
  ir::value *A = dot->get_operand(0);
  ir::value *B = dot->get_operand(1);
  ir::value *D = dot->get_operand(2);
  // add edges between result and accumulator
  for(unsigned d = 0; d < shapes.size(); d++)
    graph_.add_edge({dot, d}, {D, d});
}

void axes::update_graph_elementwise(ir::instruction *i,
                                    bool is_masked_load_async) {
  if(i->get_num_operands() == 0)
    return;
  ir::value *op = i->get_operand(0);
  if(!op->get_type()->is_block_ty())
    return;
  auto rank = op->get_type()->get_tile_rank();
  for(unsigned d = 0; d < rank; d++) {
    // If we are dealing with a masked async load we need to attach the
    // dimensions so we match the behaviour of the copy_to_shared instruction
    // which async masked load replaces.
    if (is_masked_load_async) {
      graph_.add_edge({i, d}, {i, d});
    }

    for(ir::value* opx: i->ops())
      for(ir::value* opy: i->ops()) {
        if(!is_masked_load_async && !i->get_type()->is_void_ty())
          graph_.add_edge({i, d}, {opx, d});
        graph_.add_edge({opx, d}, {opy, d});
      }
  }
}

void axes::update_graph_no_edge(ir::instruction *i) {
  if(!i->get_type()->is_block_ty())
    return;
  auto rank = i->get_type()->get_tile_rank();
  for(unsigned d = 0; d < rank; d++)
    graph_.add_edge({i, d}, {i, d});
}

void axes::update_graph(ir::instruction *i) {
  switch (i->get_id()) {
    case ir::INST_REDUCE:            return update_graph_reduce(i);
    case ir::INST_RESHAPE:           return update_graph_reshape(i);
    case ir::INST_SPLAT:             return update_graph_no_edge(i);
    case ir::INST_CAT:               return update_graph_elementwise(i, true);
    case ir::INST_TRANS:             return update_graph_trans(i);
    case ir::INST_BROADCAST:         return update_graph_broadcast(i);
    case ir::INST_DOT:               return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED:    return update_graph_no_edge(i);
    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
    case ir::INST_COPY_FROM_SHARED:  return update_graph_no_edge(i);
    case ir::INST_CVT_LAYOUT:        return update_graph_no_edge(i);
    default:                         return update_graph_elementwise(i);
  }
}


int axes::get(ir::value *value, unsigned dim) {
  return axes_.at({value, dim});
}

std::vector<int> axes::get(ir::value *value) {
  std::vector<int> result;
  for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
    result.push_back(this->get(value, d));
  return result;
}

void axes::run(ir::module &mod) {
  // make graph
  graph_.clear();
  axes_.clear();
  ir::for_each_instruction(mod, [this](ir::instruction *x) {
    update_graph(x);
  });
  // find connected components
  graph_.connected_components(nullptr, &axes_);
  std::set<size_t> uniq;
  for(auto x: axes_)
    uniq.insert(x.second);
}

}
}
}
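
Everything this pass records is an equality between (value, dimension) pairs; the connected components of that graph become the distributed axis ids that must share one thread mapping. The same idea in miniature, with union-find over integer node ids (illustrative sketch, not the graph_ class used above):

  #include <cassert>
  #include <numeric>
  #include <vector>

  struct dsu {
    std::vector<int> p;
    explicit dsu(int n) : p(n) { std::iota(p.begin(), p.end(), 0); }
    int find(int x) { return p[x] == x ? x : p[x] = find(p[x]); }
    void unite(int a, int b) { p[find(a)] = find(b); }
  };

  int main() {
    // nodes: 0 = (c, dim0), 1 = (a, dim0), 2 = (b, dim0) for c = a + b
    dsu g(3);
    g.unite(0, 1);  // elementwise ops tie the result's axes to the operands'
    g.unite(1, 2);
    assert(g.find(0) == g.find(2));  // all three share one axis id
    return 0;
  }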

--- a/lib/codegen/analysis/layout.cc
+++ /dev/null
@@ -1,653 +0,0 @@
#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
// #include "triton/ir/type.h"

namespace triton{
namespace codegen{
namespace analysis{

/* -------------------------------- *
 *          Helper Functions        *
 * -------------------------------- */

inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b);
  unsigned hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

inline bool is_hmma_c(ir::value *v, int sm){
  bool result = false;
  if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
    ir::value *a = x->get_operand(0);
    ir::type *a_ty = a->get_type();
    ir::value *b = x->get_operand(1);
    ir::type *b_ty = b->get_type();
    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
              x->allow_tf32() && sm >= 80) ||
             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
              sm >= 80);
  }
  return result;
}

static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
  mma_layout::TensorCoreType mma_type;
  if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
    ir::value* a = dot->get_operand(0);
    ir::value* b = dot->get_operand(1);
    ir::type* a_ty = a->get_type();
    ir::type* b_ty = b->get_type();
    ir::type* c_ty = v->get_type();

    if (c_ty->get_scalar_ty()->is_fp32_ty()) {
      // floating point tensor cores
      if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
        mma_type = mma_layout::FP32_FP16_FP16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
        mma_type = mma_layout::FP32_BF16_BF16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
          && dot->allow_tf32()) {
        mma_type = mma_layout::FP32_TF32_TF32_FP32;
        return mma_type;
      }
    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
      // throw std::runtime_error("integer tensor cores are not yet supported");
      // // integer tensor cores
      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
      //   return mma_type;
      // }
      // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
        mma_type = mma_layout::INT32_INT8_INT8_INT32;
        return mma_type;
      }
    }
  }
  return mma_layout::NOT_APPLICABLE;
}

inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::io_inst*>(u);
    if(i && i->get_pointer_operand() == v)
      result.insert(v);
  }
}

inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && i->get_operand(n) == v)
      result = v;
  }
}

inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
      result = i;
    }
  }
}


inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {
    return true;
  }
  if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
    bool result = true;
    for(ir::value *op: phi->ops())
      result = result && is_trans(op);
    return result;
  }
  return false;
}


/* -------------------------------- *
 *          Layout Visitor          *
 * -------------------------------- */

void layout_visitor::visit_layout(data_layout *layout) {
  layout->accept(this);
}


/* -------------------------------- *
 *        Base Data Layout          *
 * -------------------------------- */

data_layout::data_layout(id_t id,
                         const std::vector<int> &axes,
                         const std::vector<unsigned> &shape,
                         const std::vector<ir::value *> &values,
                         analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
  // io pointer
  std::set<ir::value*> ptr;
  for(ir::value* v: values_)
    extract_io_use(v, ptr);
  order_.resize(axes_.size());
  std::iota(order_.begin(), order_.end(), 0);
  std::vector<unsigned> max_contiguous;
  for(ir::value* p: ptr){
    std::vector<unsigned> curr = align->contiguous(p);
    if(curr.size() > max_contiguous.size())
      max_contiguous = curr;
    else if(curr.size() == max_contiguous.size()){
      if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
        max_contiguous = curr;
    }
  }
  if(max_contiguous.size() > 0){
    std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
      return max_contiguous[a] > max_contiguous[b];
    });
    // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
    // std::cout << order_[0] << " " << order_[1] << std::endl;
  }
}

int data_layout::find_axis(int to_find) const {
  auto it = std::find(axes_.begin(), axes_.end(), to_find);
  if(it == axes_.end())
    return -1;
  return std::distance(axes_.begin(), it);
}


distributed_layout::distributed_layout(id_t id,
                                       const std::vector<int> &axes,
                                       const std::vector<unsigned> &shape,
                                       const std::vector<ir::value *> &values,
                                       analysis::align* align): data_layout(id, axes, shape, values, align)
{ }

/* -------------------------------- *
 *           MMA Layout             *
 * -------------------------------- */

mma_layout::mma_layout(size_t num_warps,
                       const std::vector<int>& axes,
                       const std::vector<unsigned>& shape,
                       const std::vector<ir::value *> &values,
                       analysis::align* align, target* tgt,
                       shared_layout *layout_a, shared_layout *layout_b,
                       ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
  tensor_core_type_ = get_mma_type(dot);
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  if(tgt->as_nvidia()->sm() < 80){
    fpw_ = {2, 2, 1};
    auto ord_a = layout_a->get_order();
    auto ord_b = layout_b->get_order();
    bool is_a_row = ord_a[0] != 0;
    bool is_b_row = ord_b[0] != 0;
    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
    bool is_b_vec4 =  is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
    int pack_size_0 = (is_a_row ||  is_a_vec4) ? 1 : 2;
    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
    spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
    contig_per_thread_ = {1, 1};
  }
  else{
    // fpw_ = {1, 1, 1};
    spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
    contig_per_thread_ = {1, 2};
    // rep_ = {2, 2, 1};
  }
  order_ = {0, 1};

  /* warps per tile */
  wpt_ = {1, 1, 1};
  // try to make warp-level tiles as square as possible to maximize data re-use
  if (tgt->as_nvidia()->sm() < 80) {
    std::vector<int> wpt_nm1;
    do{
      wpt_nm1 = wpt_;
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
    }while(wpt_nm1 != wpt_);
  } else {
    bool changed = false;
    do {
      changed = false;
      if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
        break;
      if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
        if (wpt_[0] < shape_[0] / spw_[0]) {
          wpt_[0] *= 2;
          changed = true;
        }
      } else {
        if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
          wpt_[1] *= 2;
          changed = true;
        }
      }
    } while (changed);
  }

  /* shape per block */
  shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
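
For example, on sm >= 80 with spw = {16, 8, 16} and num_warps = 8 for a 128x128 output tile, the loop doubles wpt_[0] and wpt_[1] alternately, preferring the dimension with more uncovered tiles, and settles on wpt = {4, 2, 1}, i.e. shape_per_cta = {64, 16, 1}. A standalone replay of that arithmetic (assumed inputs):

  #include <cassert>

  int main() {
    int spw[2] = {16, 8}, shape[2] = {128, 128};
    int num_warps = 8, wpt[3] = {1, 1, 1};
    bool changed;
    do {
      changed = false;
      if (wpt[0] * wpt[1] * wpt[2] >= num_warps) break;
      if (shape[0] / spw[0] / wpt[0] >= shape[1] / (spw[1] * 2) / wpt[1]) {
        if (wpt[0] < shape[0] / spw[0]) { wpt[0] *= 2; changed = true; }
      } else {
        if (wpt[1] < shape[1] / (spw[1] * 2)) { wpt[1] *= 2; changed = true; }
      }
    } while (changed);
    assert(wpt[0] == 4 && wpt[1] == 2 && wpt[2] == 1);
    return 0;
  }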

/* -------------------------------- *
 *         Scanline Layout          *
 * -------------------------------- */

scanline_layout::scanline_layout(size_t num_warps,
                                 const std::vector<int>& axes,
                                 const std::vector<unsigned>& shape,
                                 const std::vector<ir::value *> &values,
                                 analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
  unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
  unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
  nts_.resize(shape_.size());
  mts_.resize(shape_.size());
  bool is_dot = std::any_of(values.begin(), values.end(),
                            [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

  std::vector<ir::value*> ptrs;
  for(ir::value *v: values)
    for(ir::user *usr: v->get_users())
      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
        if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
          ptrs.push_back(io->get_pointer_operand());
      }

  unsigned i = order_[0];
  int contiguous = 1;
  for(ir::value* ptr: ptrs){
    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
    contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
  }

  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
  size /= shape_[i];
  num_threads /= mts_[i];
  if(is_dot)
    nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
  for(size_t d = 1; d < shape_.size(); d++){
    i = order_[d];
    if(d > 1 || !is_dot)
      nts_[i] = 1;
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    num_threads = num_threads / mts_[i];
  }

  shape_per_cta_.resize(shape_.size());
  for(size_t d = 0; d < shape_.size(); d++)
    shape_per_cta_[d] = mts_[d]*nts_[d];
}
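
Concretely: for a 128x64 fp32 tile whose contiguous axis is the last one (order = {1, 0}), 4 warps (128 threads), and a pointer aligned enough for 128-bit accesses, the fast axis gets nts = 4 elements per thread (a float4) and mts = 16 threads, leaving 8 threads to stride the slow axis. A standalone replay of the arithmetic (assumed inputs, not the original classes):

  #include <algorithm>
  #include <cassert>

  unsigned clamp(unsigned x, unsigned a, unsigned b) {
    return std::min(std::max(x, std::min(a, b)), std::max(a, b));
  }

  int main() {
    unsigned shape[2] = {128, 64};           // dim 1 is the contiguous one
    int order[2] = {1, 0};
    unsigned size = 128 * 64, num_threads = 4 * 32;
    int contiguous = std::min(16, 128 / 32); // alignment vs. 128-bit loads -> 4
    unsigned nts[2], mts[2];
    int i = order[0];
    nts[i] = clamp(size / num_threads, 1, std::min<unsigned>(contiguous, shape[i])); // 4
    mts[i] = clamp(num_threads, 1, shape[i] / nts[i]);                               // 16
    size /= shape[i]; num_threads /= mts[i];
    i = order[1];
    nts[i] = 1;
    mts[i] = clamp(num_threads, 1, shape[i] / nts[i]);                               // 8
    assert(nts[1] == 4 && mts[1] == 16 && mts[0] == 8);
    return 0;
  }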

/* -------------------------------- *
 *          Shared Layout           *
 * -------------------------------- */

bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  if(phi->get_parent() != terminator->get_parent())
    return false;
  if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
    return br->get_true_dest() == phi->get_parent()
           || br->get_false_dest() == phi->get_parent();
  else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  else
    throw std::runtime_error("unreachable");
}


void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  if(!phi || phi->get_num_incoming() != 2)
    return;
  ir::basic_block *block_0 = phi->get_incoming_block(0);
  ir::basic_block *block_1 = phi->get_incoming_block(1);
  ir::instruction *terminator_0 = block_0->get_inst_list().back();
  ir::instruction *terminator_1 = block_1->get_inst_list().back();
  bool is_latch_0 = is_loop_latch(phi, terminator_0);
  bool is_latch_1 = is_loop_latch(phi, terminator_1);
  ir::value *value_0 = phi->get_incoming_value(0);
  ir::value *value_1 = phi->get_incoming_value(1);
  ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
  ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
  if(!(i_0 && !i_1) &&
     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
    return;
  if(is_latch_1)
    res.reset(new double_buffer_info_t{value_0, value_1, phi});
  if(is_latch_0)
    res.reset(new double_buffer_info_t{value_1, value_0, phi});
}

static bool is_smem(ir::value* v) {
  if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
      dynamic_cast<ir::masked_load_async_inst*>(v))
    return true;
  else
    return false;
}

/// param:
///    value_1: next_value
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
                                   std::vector<ir::value*>& values_0, ir::value*& value_1) {
  ir::value* next = phi;
  while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
    // smem from previous bb & phi/smem from current bb
    ir::value* c0 = cphi->get_incoming_value(0);
    ir::value* c1 = cphi->get_incoming_value(1);
    ir::basic_block *cbb0 = cphi->get_incoming_block(0);
    ir::basic_block *cbb1 = cphi->get_incoming_block(1);

    if (is_smem(c0)) {
      assert(cbb0 == bb0);
      values_0.push_back(c0);
      if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
        next = phi1;
        continue;
      } else {
        if (is_smem(c1)) {
          value_1 = c1;
          assert(cbb1 == bb1);
          return true;
        } else {
          return false;
        }
      }
    } else
      return false;
  }
  return false;
}

void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  // if the phi node is nested
  if (!phi)
    return;

  ir::basic_block *bb0 = phi->get_incoming_block(0);
  ir::basic_block *bb1 = phi->get_incoming_block(1);

  std::vector<ir::value*> values_0;
  ir::value* value_1;

  if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
    return;

  // double-buffer is a special case
  if (values_0.size() == 1)
    return;

  // compute original values_0 input order
  std::map<ir::value*, int> order;
  int idx = 0;
  for (ir::instruction* instr : *bb0) {
    if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
      order[static_cast<ir::value*>(instr)] = idx++;
  }
  assert(order.size() == values_0.size() && "order size incorrect");

  int curr_stages = values_0.size() + 1;
  if (curr_stages > prev_stages) {
    res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
    prev_stages = curr_stages;
  }
}


shared_layout::shared_layout(data_layout *arg,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             ir::type *ty,
                             analysis::align* align, target *tgt)
    : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {

  size_ = 0;
  arg_layout_ = arg;

  // N-stage buffering
  int prev_stages = 0;
  for (ir::value *v : values)
    extract_N_bufferable(v, N_buffer_, prev_stages);

  // double-buffering
  if (!N_buffer_)
    for(ir::value *v: values)
      extract_double_bufferable(v, double_buffer_);

  // order
  std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
  order_ = arg_order;

  ir::value* dot_a = nullptr;
  ir::value* dot_b = nullptr;
  ir::value* hmma_dot_a = nullptr;
  ir::value* hmma_dot_b = nullptr;
  for(ir::value* v: values){
    extract_dot_use(v, dot_a, 0);
    extract_dot_use(v, dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
    extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
  }
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // Update mma_vec
  if (hmma_dot_a_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
    mma_vec_     = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 0) // need transpose
        allow_swizzle_ = false;
  } else if (hmma_dot_b_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
    mma_vec_     = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 1) // need transpose
        allow_swizzle_ = false;
  }

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
  for(auto s: shape_)
    size_ *= s;
  if(double_buffer_)
    size_ *= 2;
  if (N_buffer_) {
    size_ *= (N_buffer_->firsts.size() + 1);
  }
}

int shared_layout::get_num_stages() const {
  if (double_buffer_)
    return 2;
  if (N_buffer_)
    return N_buffer_->firsts.size() + 1;
  return 1;
}

size_t shared_layout::get_per_stage_elements() const {
  return get_per_stage_size() / (ty_->get_primitive_size_in_bits() / 8);
}

/* -------------------------------- *
 * ---- Layouts Inference Pass ---- *
 * -------------------------------- */

layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
    : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }


void layouts::connect(ir::value *x, ir::value *y) {
  if(x == y)
    return;
  if(!x->get_type()->is_block_ty())
    return;
  if(!y->get_type()->is_block_ty())
    return;
  std::vector<int> x_axes = axes_->get(x);
  std::vector<int> y_axes = axes_->get(y);
  std::set<int> sx_axes(x_axes.begin(), x_axes.end());
  std::set<int> sy_axes(y_axes.begin(), y_axes.end());
  std::set<int> common;
  std::set_intersection(sx_axes.begin(), sx_axes.end(),
                        sy_axes.begin(), sy_axes.end(),
                        std::inserter(common, common.begin()));
  graph_.add_edge(x, x);
  graph_.add_edge(y, y);
  if(!common.empty())
    graph_.add_edge(x, y);
}

void layouts::make_graph(ir::instruction *i) {
  for(ir::value* opx: i->ops())
    for(ir::value* opy: i->ops()){
      connect(i, opx);
      connect(opx, opy);
    }
}

void layouts::create(size_t id, const std::vector<ir::value*>& values) {
  // if(layouts_.find(id) != layouts_.end())
  //   return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(),
                                [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
    return xx < yy;
  };
  std::vector<ir::value*> lvalue = values;
  // erase-remove: actually drop transpositions before picking the largest value
  lvalue.erase(std::remove_if(lvalue.begin(), lvalue.end(),
                              [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); }),
               lvalue.end());
  ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
  const auto& axes = axes_->get(largest);
  const auto& shapes = largest->get_type()->get_block_shapes();
  auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
           dynamic_cast<ir::masked_load_async_inst*>(v);
  });
  // type
  if(it_hmma_c != values.end()){
    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    create(groups_.at(a), values_.at(groups_.at(a)));
    create(groups_.at(b), values_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
                                  (shared_layout*)layouts_.at(groups_.at(a)),
                                  (shared_layout*)layouts_.at(groups_.at(b)),
                                  dot);
  }
  else if(it_cts != values.end()){
    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
  }
  else{
    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
  }
}

void layouts::run(ir::module &mod) {
  // make graph
  graph_.clear();
  layouts_.clear();
  groups_.clear();

  ir::for_each_instruction(mod, [this](ir::instruction* i) {
    make_graph(i);
  });

  // connected components
  graph_.connected_components(&values_, &groups_);

  // create layouts
  for(const auto& x: values_)
    create(x.first, x.second);

  // create temporaries
  size_t id = values_.size();
  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
      id++;
      ir::value *arg = red->get_operand(0);
      unsigned axis = red->get_axis();
      // shape
      auto shapes = arg->get_type()->get_block_shapes();
      scanline_layout *layout = get(arg)->to_scanline();
      shapes[axis] = layout->mts(axis);
      // create layout
      layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
      tmp_[red] = id;
    }
    if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
      distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
      distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
      id++;
      size_t dim = val->get_type()->get_tile_rank();
      ir::type::block_shapes_t shape(dim);
      for(size_t k = 0; k < dim; k++){
        shape[k] = std::max(in_layout->shape_per_cta(k),
                            out_layout->shape_per_cta(k));
      }
      auto in_ord = in_layout->get_order();
      auto out_ord = out_layout->get_order();
      int in_vec = in_layout->contig_per_thread(in_ord[0]);
      int out_vec = out_layout->contig_per_thread(out_ord[0]);
      int pad = std::max(in_vec, out_vec);
      shape[out_ord[0]] += pad;
      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
      tmp_[val] = id;
    }
    if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
      id++;
      layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
      tmp_[atom] = id;
    }
  });
}

}
}
}

--- a/lib/codegen/analysis/liveness.cc
+++ /dev/null
@@ -1,59 +0,0 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"

namespace triton{
namespace codegen{
namespace analysis{


void liveness::run(ir::module &mod) {
  intervals_.clear();

  // Assigns index to each instruction
  std::map<ir::value*, slot_index> indices;
  for(ir::function *fn: mod.get_function_list()){
    slot_index index = 0;
    for(ir::basic_block *block: fn->blocks())
      for(ir::instruction *instr: block->get_inst_list()){
        index += 1;
        indices.insert({instr, index});
      }
  }

  // create live intervals
  for(auto &x: layouts_->get_all()) {
    shared_layout* layout = x.second->to_shared();
    if(!layout)
      continue;
    // users
    std::set<ir::user*> users;
    for(ir::value *v: layout->get_values()){
      for(ir::user *u: v->get_users())
        users.insert(u);
    }
    // compute intervals
    unsigned start = INT32_MAX;
    for(ir::value *v: layout->get_values())
      if(indices.find(v) != indices.end())
        start = std::min(start, indices.at(v));
    unsigned end = 0;
    for(ir::user *u: users)
      if(indices.find(u) != indices.end())
        end = std::max(end, indices.at(u));
    if(end == 0)
      end = start + 1;
    intervals_[layout] = segment{start, end};
  }
}

}
}
}
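
Slot indices are just instruction positions, so a buffer's interval runs from its first definition to its last use; two buffers can share memory in the allocator above only if such intervals are disjoint. Expressed directly (sketch; the overlap test shown here is an assumption, since segment's definition is not part of this commit):

  #include <cassert>

  struct segment_t {
    unsigned start, end;
    bool intersect(const segment_t &o) const {
      return start < o.end && o.start < end;  // half-open overlap (assumed)
    }
  };

  int main() {
    segment_t a{5, 12};   // defined at slot 5, last used at slot 12
    segment_t b{12, 20};  // defined exactly where `a` dies
    segment_t c{8, 15};
    assert(!a.intersect(b));  // disjoint: may share memory
    assert(a.intersect(c));   // overlapping: must not alias
    return 0;
  }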

--- a/lib/codegen/analysis/swizzle.cc
+++ /dev/null
@@ -1,61 +0,0 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>

namespace triton{
namespace codegen{
namespace analysis{


void swizzle::run(ir::module &) {
  per_phase_.clear();
  max_phase_.clear();

  for(auto &x: layouts_->get_all()){
    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
    if(!layout)
      continue;
    ir::value* mma_dot_a = layout->hmma_dot_a();
    ir::value* mma_dot_b = layout->hmma_dot_b();

    if(!mma_dot_a && !mma_dot_b){
      per_phase_[layout] = 1;
      max_phase_[layout] = 1;
      vec_[layout] = 1;
      continue;
    }
    auto ord = layout->get_order();
    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
    if(!in_layout)
      continue;
    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
      int inner = mma_dot_a ? 0 : 1;
      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
      if(mma_dot_a)
        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
      else
        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
    }
    else {
      if (!layout->allow_swizzle()) {
        per_phase_[layout] = 1;
        max_phase_[layout] = 1;
        vec_[layout] = 1;
      } else {
        per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
        max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
        vec_[layout] = layout->get_mma_vec();
      }
    }
  }
}

}
}
}
|
||||
|
||||
|
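
This pass only derives the swizzling parameters; applying them happens later in the code generator. As a rough illustration (an assumption about how such parameters are typically consumed, not code from this commit), a phase computed per row can be xor-ed into the column index so that rows map to distinct shared-memory banks:

#include <algorithm>
#include <cstdio>

int main() {
  int mts = 8, nts = 4, dtsize = 2;                         // hypothetical scanline layout
  int per_phase = std::max(128 / (mts * nts * dtsize), 1);  // rows sharing one phase
  int max_phase = 8 / per_phase;                            // phases before the pattern repeats
  for (int row = 0; row < 4; row++) {
    int phase = (row / per_phase) % max_phase;
    for (int col = 0; col < 4; col++)
      std::printf("row %d: col %d -> %d\n", row, col, col ^ phase);
  }
  return 0;
}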
@@ -1,86 +0,0 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"

namespace triton {
namespace codegen {

// TODO:
// There should be a proper pass manager here!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
                                                     int cc, int num_warps, int num_stages, int& shared_static) {
  // generate llvm code
  std::string name = ir.get_function_list()[0]->get_name();
  std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
  // optimizations
  bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
  // create passes
  codegen::analysis::align align;
  codegen::analysis::axes axes;
  codegen::transform::cts cts(cts_use_async);
  codegen::transform::pipeline pipeline(cts_use_async, num_stages);
  codegen::transform::disassociate disassociate;
  codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
  codegen::analysis::liveness liveness(&layouts);
  codegen::analysis::swizzle swizzle(&layouts, target);
  codegen::analysis::allocation allocation(&liveness);
  codegen::transform::dce dce;
  codegen::transform::peephole peephole(target, &layouts);
  codegen::transform::coalesce coalesce(&align, &layouts);
  codegen::transform::prefetch prefetch_s(target);
  codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
  // run passes
  dce.run(ir);
  peephole.run(ir);
  dce.run(ir);
  pipeline.run(ir);
  dce.run(ir);
  disassociate.run(ir);
  dce.run(ir);
  align.run(ir); axes.run(ir); layouts.run(ir);
  peephole.run(ir);
  dce.run(ir);
  if (target->is_gpu())
    cts.run(ir);
  align.run(ir); axes.run(ir); layouts.run(ir);
  coalesce.run(ir);
  dce.run(ir);
  align.run(ir);
  dce.run(ir);
  if (target->is_gpu())
    cts.run(ir);
  dce.run(ir);
  align.run(ir); axes.run(ir); layouts.run(ir);
  peephole.run(ir);
  dce.run(ir);
  align.run(ir); axes.run(ir); layouts.run(ir);
  swizzle.run(ir);
  liveness.run(ir);
  allocation.run(ir);
  prefetch_s.run(ir);
  barriers.run(ir);
  isel.visit(ir, *llvm);
  shared_static = allocation.allocated_size();
  return llvm;
}

} // namespace codegen
} // namespace triton
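
For reference, a hypothetical call site for the standalone compilation entry point above; the wrapper name and the parameter values are invented, and the ir::module and target must come from the rest of the library.

#include <memory>
#include "triton/codegen/pass.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

// Sketch only: compiles a Triton-IR module to an LLVM module with
// made-up values for compute capability, warps and pipeline stages.
std::unique_ptr<llvm::Module> compile_sketch(triton::ir::module& ir_mod,
                                             llvm::LLVMContext& ctx,
                                             triton::codegen::target* tgt,
                                             int& shared_static) {
  return triton::codegen::add_passes_to_emit_bin(
      ir_mod, ctx, tgt, /*cc=*/80, /*num_warps=*/4,
      /*num_stages=*/3, shared_static);
}

// After the call, shared_static holds the shared-memory footprint
// computed by the allocation pass.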
File diff suppressed because it is too large
@@ -1,173 +0,0 @@
#include "triton/codegen/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Value.h"
#include <iostream>

using namespace llvm;

namespace triton{
namespace codegen{

// base

nvidia_cu_target* target::as_nvidia() {
  return dynamic_cast<nvidia_cu_target*>(this);
}

bool target::is_gpu() const {
  return is_gpu_;
}

// AMD
void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
  fn->setCallingConv(CallingConv::AMDGPU_KERNEL);
}

Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}

Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* group_id = get_block_id(module, builder, ax);
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
}

Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
  throw std::runtime_error("not implemented");
}

Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::amdgcn_workgroup_id_x,
    Intrinsic::amdgcn_workgroup_id_y,
    Intrinsic::amdgcn_workgroup_id_z
  };
  Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
  return group_id;
}

Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  throw std::runtime_error("not implemented on AMD");
}

Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::amdgcn_workitem_id_x,
    Intrinsic::amdgcn_workitem_id_y,
    Intrinsic::amdgcn_workitem_id_z
  };
  Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
  return builder.CreateCall(get_local_id, {});
}

// NVIDIA

void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){
  // set metadata
  Metadata *md_args[] = {
    ValueAsMetadata::get(fn),
    MDString::get(ctx, "kernel"),
    ValueAsMetadata::get(builder.getInt32(1))
  };
  module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
}

Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) {
  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0);
  return builder.CreateCall(barrier, {});
}

Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
  return builder.CreateCall(barrier, {});
}

Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* group_id = get_block_id(module, builder, ax);
  Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
  return result;
}

Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> cta_ids = {
    Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_z
  };
  Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
  return cta_id;
}

Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::nvvm_read_ptx_sreg_tid_x,
    Intrinsic::nvvm_read_ptx_sreg_tid_y,
    Intrinsic::nvvm_read_ptx_sreg_tid_z
  };
  Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]);
  return builder.CreateCall(get_local_id, {});
}

Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  static std::array<Intrinsic::ID, 3> ids = {
    Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_z
  };
  return builder.CreateIntrinsic(ids[ax], {}, {});
}

// CPU

void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
  // normal cpu functions can be kernels
}

Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
  // no barrier on CPU
  return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}

Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
  // no memory fence on CPU
  return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}

Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
  const Function *fn = builder.GetInsertBlock()->getParent();
  size_t num_params = fn->getFunctionType()->getNumParams();
  // the last three parameters of a CPU kernel carry the block indices;
  // must not be static, since the arguments depend on the current function
  std::array<const Argument*, 3> ids = {
    fn->arg_begin() + num_params - 3,
    fn->arg_begin() + num_params - 2,
    fn->arg_begin() + num_params - 1
  };
  return (Argument*)ids[ax];
}

Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
  throw std::runtime_error("not implemented");
}

Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
  Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
  return result;
}

Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
  return builder.getInt32(0);
}

}
}
@@ -1,133 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"

namespace triton {
namespace codegen{
namespace transform{

coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
  : align_(align), layout_(layouts) { }


// simplify layout conversions using the following simple rules:
//   - cvt_1(cvt_2(x)) = x if cvt_1 is the inverse of cvt_2
//   - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_1(y))
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
//  ir::value* _op = inst->get_operand(0);
//  ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
//  analysis::mma_layout* mma_in  = layout_->get(op)  ->to_mma();
//  analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
//  std::cout << 1 << std::endl;
//  // i must be layout conversion instruction
//  if(!mma_in && !mma_out)
//    return inst;
//  // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
//  bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
//  if((mma_in || mma_out) && is_op_cvt &&
//     (layout_->get(inst) == layout_->get(op->get_operand(0))))
//    return op->get_operand(0);
//  // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
//  if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
//    return inst;
//  std::cout << 1 << std::endl;
//  for(size_t i = 0; i < op->get_num_operands(); i++){
//    ir::value* arg_i = op->get_operand(i);
//    builder.set_insert_point(op);
//    // create new layout transform
//    ir::instruction* new_arg_i = inst->clone();
//    builder.insert(new_arg_i);
//    // set the right args
//    new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
//    op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
//  }
//  std::cout << 2 << std::endl;
//  return op;
//}

void coalesce::run(ir::module &mod) {
  ir::builder& builder = mod.get_builder();
  // add layout conversion instructions
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction* i: block->get_inst_list()){
    // coalesce before store
    if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
    if(ir::value* op = i->get_operand(1))
    if(op->get_type()->is_block_ty())
    if(layout_->get(op)->to_mma()){
      ir::instruction* new_op = ir::cvt_layout_inst::create(op);
      builder.set_insert_point(i);
      builder.insert(new_op);
      i->replace_uses_of_with(op, new_op);
    }
    // uncoalesce after load
    if(auto x = dynamic_cast<ir::load_inst*>(i))
    if(x->get_type()->is_block_ty())
    if(x->get_type()->get_tile_rank()==2)
    if(layout_->get(x)->to_mma()){
      builder.set_insert_point_after(x);
      ir::instruction* new_x = ir::cvt_layout_inst::create(x);
      builder.insert(new_x);
      x->replace_all_uses_with(new_x);
      new_x->replace_uses_of_with(new_x, x);
    }
  }
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction* i: block->get_inst_list()){
    // re-arrange scanline to promote memory coalescing
    if(auto x = dynamic_cast<ir::store_inst*>(i)){
      ir::value* ptr = x->get_pointer_operand();
      ir::value* val = x->get_value_operand();
      auto out_contig = align_->contiguous(ptr);
      auto val_inst = dynamic_cast<ir::instruction*>(val);
      if(!val_inst)
        break;
      if(dynamic_cast<ir::cvt_layout_inst*>(val))
        break;
      std::vector<unsigned> in_contig;
      std::vector<ir::instruction*> queue = {val_inst};
      std::set<ir::instruction*> seen;
      while(!queue.empty()){
        ir::instruction* curr = queue.back();
        seen.insert(curr);
        queue.pop_back();
        if(dynamic_cast<ir::dot_inst*>(curr))
          break;
        if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
          in_contig = align_->contiguous(io_inst->get_pointer_operand());
          break;
        }
        for(ir::value* op: curr->ops()){
          auto inst_op = dynamic_cast<ir::instruction*>(op);
          if(!inst_op || seen.find(inst_op) != seen.end())
            continue;
          if(!op->get_type()->is_block_ty() ||
             !val->get_type()->is_block_ty())
            continue;
          if(op->get_type()->get_tile_num_elements() ==
             val->get_type()->get_tile_num_elements())
            queue.push_back(inst_op);
        }
      }
      if(in_contig.size() <= 1 || out_contig==in_contig)
        continue;
      builder.set_insert_point_after(val_inst);
      auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
      x->replace_uses_of_with(val_inst, new_val);
    }
  }
}


}
}
}
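
The final decision in the loop above boils down to a predicate over the two contiguity vectors. A standalone sketch (not part of the original file) with invented values:

#include <cstdio>
#include <vector>

bool needs_relayout(const std::vector<unsigned>& in_contig,
                    const std::vector<unsigned>& out_contig) {
  // a single-element (or empty) contiguity vector carries no layout
  // information, and identical vectors mean the store is already coalesced
  return in_contig.size() > 1 && in_contig != out_contig;
}

int main() {
  std::printf("%d\n", needs_relayout({1, 4}, {4, 1}));  // 1: insert a cvt_layout
  std::printf("%d\n", needs_relayout({4, 1}, {4, 1}));  // 0: leave the store as-is
  return 0;
}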
@@ -1,97 +0,0 @@
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include <iostream>

namespace triton {
namespace codegen{
namespace transform{


inline bool is_shmem_op(ir::instruction* i, int op) {
  if(i->get_id() == ir::INST_DOT)
    return op==0 || op==1;
  if(i->get_id() == ir::INST_COPY_FROM_SHARED)
    return op==0;
  if(i->get_id() == ir::INST_TRANS)
    return op==0;
  return false;
}

inline bool is_shmem_res(ir::value* v){
  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
  if(!i)
    return false;
  if(i->get_id() == ir::INST_TRANS)
    return true;
  if(i->get_id() == ir::INST_COPY_TO_SHARED)
    return true;
  if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
    return true;
  return false;
}


// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
  auto *i = dynamic_cast<ir::instruction*>(x);
  // not an instruction
  if(!i) {
    builder.set_insert_point(parent);
    ir::value *copy;
    if(to_shared)
      copy = builder.create_copy_to_shared(x);
    else
      copy = builder.create_copy_from_shared(x);
    parent->replace_uses_of_with(x, copy);
    return;
  }
  // phi node
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
      add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
    return;
  }
  // already in shared memory
  if(to_shared && is_shmem_res(i))
    return;
  // copy
  builder.set_insert_point_after(i);
  ir::value *copy;
  if(to_shared)
    copy = builder.create_copy_to_shared(x);
  else
    copy = builder.create_copy_from_shared(x);
  parent->replace_uses_of_with(x, copy);
}

void cts::run(ir::module &mod) {
  // Add shared copies
  ir::builder &builder = mod.get_builder();
  for(ir::function* fn: mod.get_function_list()){
    for(ir::basic_block* block: fn->blocks())
    for(ir::instruction* i: block->get_inst_list()){
      size_t num_op = i->get_num_operands();
      // copy to shared operands
      for(size_t k = 0; k < num_op; k++)
        if(is_shmem_op(i, k)){
          add_copy(i, i->get_operand(k), builder, true);
        }
      // copy from shared operands
      for(size_t k = 0; k < num_op; k++)
        if(!dynamic_cast<ir::phi_node*>(i) &&
           !is_shmem_op(i,k) &&
           is_shmem_res(i->get_operand(k))){
          add_copy(i, i->get_operand(k), builder, false);
        }
    }
  }
}


}
}
}
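
The pass is driven by the operand classification at the top of the file. A standalone sketch (not part of the original file) of the same decision table, with stand-in enum values rather than the real instruction ids:

#include <cstdio>

enum InstKind { DOT, TRANS, COPY_FROM_SHARED, OTHER };

// does operand 'op' of an instruction of kind 'k' have to live in shared memory?
bool shmem_operand(InstKind k, int op) {
  if (k == DOT)              return op == 0 || op == 1;  // the A and B tiles
  if (k == COPY_FROM_SHARED) return op == 0;
  if (k == TRANS)            return op == 0;
  return false;
}

int main() {
  // the dot accumulator (operand 2) stays in registers
  std::printf("dot operand 2 needs shared? %d\n", shmem_operand(DOT, 2));  // 0
  std::printf("dot operand 0 needs shared? %d\n", shmem_operand(DOT, 0));  // 1
  return 0;
}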
@@ -1,75 +0,0 @@
#include "triton/codegen/transform/dce.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"

namespace triton {
namespace codegen{
namespace transform{


void dce::run(ir::module &mod) {
  std::list<ir::instruction*> work_list;
  std::set<ir::instruction*> marked;

  // initialize work-list
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    // iterate through blocks
    for(ir::basic_block *block: rpo)
    for(ir::instruction *i: block->get_inst_list()){
      switch(i->get_id()){
        case ir::INST_RETURN:
        case ir::INST_UNCOND_BRANCH:
        case ir::INST_COND_BRANCH:
        case ir::INST_UNMASKED_STORE:
        case ir::INST_MASKED_STORE:
        case ir::INST_ATOMIC_CAS:
        case ir::INST_ATOMIC_RMW:
        case ir::INST_ATOMIC_EXCH:
        case ir::INST_BARRIER: {
          work_list.push_back(i);
          marked.insert(i);
          break;
        }
        default:
          break;
      }
    }
  }

  // mark -- ignore branches
  while(!work_list.empty()){
    ir::instruction* current = work_list.back();
    work_list.pop_back();
    // mark instruction operands
    for(ir::value* op: current->ops()) {
      if(auto *i = dynamic_cast<ir::instruction*>(op)){
        if(marked.insert(i).second)
          work_list.push_back(i);
      }
    }
    // TODO: mark the last instruction of current's reverse-dominance frontier
  }

  // sweep -- delete non-branch unmarked instructions
  std::vector<ir::instruction*> to_delete;
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    // iterate through blocks
    for(ir::basic_block *block: rpo)
    for(ir::instruction *i: block->get_inst_list()){
      if(marked.find(i) == marked.end())
        to_delete.push_back(i);
    }
  }

  // delete
  for(ir::instruction* i: to_delete)
    i->erase_from_parent();
}

}
}
}
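
The same mark-and-sweep shape, reduced to a standalone toy so the worklist logic can be run in isolation (not part of the original file; the instruction graph and roots are invented):

#include <cstdio>
#include <list>
#include <set>
#include <vector>

struct Inst { const char* name; std::vector<Inst*> ops; bool has_side_effect; };

int main() {
  Inst a{"a", {}, false}, b{"b", {&a}, false}, dead{"dead", {}, false};
  Inst store{"store", {&b}, true};                 // root: has a side effect
  std::vector<Inst*> all = {&a, &b, &dead, &store};

  std::list<Inst*> work;
  std::set<Inst*> marked;
  for (Inst* i : all)                              // seed with side-effecting roots
    if (i->has_side_effect) { work.push_back(i); marked.insert(i); }
  while (!work.empty()) {                          // mark transitive operands
    Inst* cur = work.back(); work.pop_back();
    for (Inst* op : cur->ops)
      if (marked.insert(op).second) work.push_back(op);
  }
  for (Inst* i : all)                              // sweep the unmarked
    if (!marked.count(i)) std::printf("deleting %s\n", i->name);  // deleting dead
  return 0;
}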
@@ -1,62 +0,0 @@
#include "triton/codegen/transform/disassociate.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/builder.h"
#include "triton/ir/module.h"
#include <iostream>

namespace triton {
namespace codegen{
namespace transform{

ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
                               std::set<ir::value*>& seen) {
  if (dynamic_cast<ir::phi_node*>(root))
    return root;
  if(!seen.insert(root).second)
    return root;
  if(!root->get_type()->is_block_ty())
    return root;

  bld.set_insert_point(root);
  ir::instruction *new_root = bld.insert(root->clone());
  for(ir::value *op: root->ops()){
    ir::instruction *i = dynamic_cast<ir::instruction*>(op);
    if(!i || i->get_id() == ir::INST_REDUCE)
      continue;
    ir::instruction* new_op = rematerialize(bld, i, seen);
    new_root->replace_uses_of_with(op, new_op);
  }
  return new_root;
}

void disassociate::run(ir::module &mod) {
  ir::builder &bld = mod.get_builder();

  // ir::for_each_instruction(mod, [&](ir::instruction *i){
  //   bld.set_insert_point(i);
  //   for(ir::value* op: i->ops()){
  //     auto reshape = dynamic_cast<ir::make_range*>(op);
  //     if(!reshape)
  //       continue;
  //     ir::instruction* new_op = bld.insert(reshape->clone());
  //     i->replace_uses_of_with(op, new_op);
  //   }
  // });

  ir::for_each_instruction(mod, [&](ir::instruction *i){
    if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
      std::set<ir::value*> seen;
      ir::instruction* new_i = rematerialize(bld, i, seen);
      i->replace_all_uses_with(new_i);
    }
  });
}


}
}
}
@@ -1,244 +0,0 @@
#include <vector>
#include <set>
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"

namespace triton {

namespace codegen{
namespace transform{


int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
    if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
      return group_of(info->first, async_write);
    else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
      if (v == info->phi)
        return group_of(info->firsts[0], async_write);
      else // prefetched value
        return group_of(info->firsts[1], async_write);
    }
    std::vector<int> groups(phi->get_num_operands());
    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
    return *std::max_element(groups.begin(), groups.end());
  }
  else{
    if(layouts_->has_tmp(v))
      return async_write.size() - 1;
    auto it = std::find(async_write.begin(), async_write.end(), v);
    return std::distance(async_write.begin(), it);
  }
}

inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
  if(!a_layout || !b_layout)
    return false;
  int a_start = alloc_->offset(a_layout);
  int a_end = a_start + a_layout->get_size();
  int b_start = alloc_->offset(b_layout);
  int b_end = b_start + b_layout->get_size();
  if(a_start < b_end || b_start < a_end)
    return true;
  return false;
}

membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
  val_set_t ret;
  for(ir::value* a: as){
    if(!a->get_type()->is_block_ty())
      continue;
    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
    analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
    for(ir::value* b: bs){
      if(!b->get_type()->is_block_ty())
        continue;
      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
      analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
      if(intersect_with(a_layout, b_layout) ||
         intersect_with(a_layout, b_tmp) ||
         intersect_with(a_tmp, b_layout) ||
         intersect_with(a_tmp, b_tmp))
        ret.insert(b);
    }
  }
  return ret;
}

bool membar::check_safe_war(ir::instruction* i) {
  bool is_i_shared_block = i->get_type()->is_block_ty() &&
                           layouts_->get(i)->to_shared();
  bool is_i_double_buffered = is_i_shared_block &&
                              layouts_->get(i)->to_shared()->get_double_buffer();
  bool is_i_n_buffered = is_i_shared_block &&
                         layouts_->get(i)->to_shared()->get_N_buffer();

  if (is_i_double_buffered || is_i_n_buffered) {
    // with async copy & prefetch_s disabled, WARs are not safe
    if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
      return false;
    else
      return true;
  }
  return false;
}

void membar::transfer(ir::basic_block *block,
                      val_vec_t& async_write,
                      val_set_t& sync_write,
                      val_set_t& sync_read,
                      std::set<ir::value*>& safe_war,
                      bool& inserted, ir::builder& builder) {
  std::vector<ir::async_wait_inst*> async_waits;
  ir::basic_block::inst_list_t instructions = block->get_inst_list();
  for(ir::instruction *i: instructions){
    if(dynamic_cast<ir::phi_node*>(i))
      continue;
    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
       dynamic_cast<ir::masked_load_async_inst*>(i)){
      async_write.push_back(i);
    }
    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
      sync_write.insert(i);
    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
    // Get shared memory reads
    std::set<ir::value*> read;
    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
                 [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();});
    if(layouts_->has_tmp(i))
      read.insert(i);
    // RAW (async)
    val_set_t tmp;
    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
    if(intersect_with(read, tmp).size()){
      std::vector<int> groups(read.size());
      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
      int N = *std::max_element(groups.begin(), groups.end());
      if(N < async_write.size()){
        builder.set_insert_point(i);
        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
        barrier = (ir::barrier_inst*)builder.create_barrier();
        inserted = true;
        async_waits.push_back(async_wait);
      }
    }
    // RAW, WAR
    bool is_safe_war = check_safe_war(i);
    // WAR barrier is not required when data is double-buffered
    if(!intersect_with(read, sync_write).empty() ||
       (!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
      builder.set_insert_point(i);
      barrier = (ir::barrier_inst*)builder.create_barrier();
      inserted = true;
    }
    // update state of asynchronous copies
    if(async_wait){
      int N = async_write.size() - async_wait->get_N();
      async_write.erase(async_write.begin(), async_write.begin() + N);
    }
    // all the copy_to_shared and read from shared are synchronized after barrier
    if(barrier){
      sync_write.clear();
      sync_read.clear();
    }
    sync_read.insert(read.begin(), read.end());
  }

  // coalesce barriers
  // FIXME: support more general cases
  if (async_waits.size() == 2) {
    // (aw N; bar; prefetch; aw N-1; bar; prefetch;  =>  aw N-1; bar; 2*prefetch;)
    for (int idx=0; idx<async_waits.size()-1; ++idx) {
      ir::async_wait_inst *first_async_wait = async_waits[idx];
      std::vector<ir::instruction*> to_erase;
      ir::basic_block::inst_list_t instructions = block->get_inst_list();
      for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
        ir::instruction *i = *iter;
        if (static_cast<ir::instruction*>(first_async_wait) == i) {
          // peek at the next 5 instructions
          auto peek_iter = std::next(iter);
          if (std::distance(peek_iter, instructions.end()) >= 5) {
            auto first_bar = dynamic_cast<ir::barrier_inst*>(*peek_iter++);
            auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peek_iter++);
            auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peek_iter++);
            auto second_bar = dynamic_cast<ir::barrier_inst*>(*peek_iter++);
            auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peek_iter);
            if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
              int second_n = second_async_wait->get_N();
              to_erase.push_back(second_async_wait);
              to_erase.push_back(second_bar);
              first_async_wait->set_N(second_n);
            }
          } else
            break;
          for (ir::instruction *i : to_erase)
            block->erase(i);
        }
      }
    }
  }
}

void membar::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // extract phi-nodes associated with double-buffered
  // shared-memory copies. These can be read from and written to
  // without needing synchronization
  std::set<ir::value*> safe_war;
  for(const auto& x: layouts_->get_all()){
    analysis::shared_layout* layout = x.second->to_shared();
    if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
      continue;
    for(ir::value *v: layout->get_values())
      if(v != layout->get_double_buffer()->phi){
        safe_war.insert(v);
      }
  }

  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    std::map<ir::basic_block*, val_vec_t> async_writes;
    std::map<ir::basic_block*, val_set_t> sync_writes;
    std::map<ir::basic_block*, val_set_t> sync_reads;
    bool inserted;
    do{
      inserted = false;
      // find barrier location
      for(ir::basic_block *block: rpo){
        // join inputs
        val_vec_t async_write;
        val_set_t sync_write;
        val_set_t sync_read;
        val_set_t tmp;
        for(ir::basic_block* pred: block->get_predecessors()){
          for(ir::value* v: async_writes[pred])
            if(tmp.insert(v).second)
              async_write.push_back(v);
          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
        }
        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
        async_writes[block] = async_write;
        sync_writes[block] = sync_write;
        sync_reads[block] = sync_read;
      }
    }while(inserted);
  }
}

}
}
}
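
The argument passed to create_async_wait above deserves a worked example: if a read depends on async-copy group N out of W outstanding copies, the pass waits until only W - 1 - N copies remain in flight. A standalone sketch (not part of the original file) with invented numbers:

#include <cstdio>

int main() {
  int outstanding = 4;   // async_write.size(): copies currently in flight
  int newest_dep = 1;    // N: highest group index any shared-memory read depends on
  int wait_arg = outstanding - 1 - newest_dep;
  // async_wait(2): block until at most 2 copies are still pending, which
  // guarantees groups 0 and 1 have landed in shared memory
  std::printf("emit async_wait %d\n", wait_arg);
  return 0;
}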
@@ -1,309 +0,0 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"

namespace triton {
namespace codegen{
namespace transform{


ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
                                  const std::vector<int>& perm) {
  if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
    // transpose operands
    std::vector<ir::value*> incs;
    for(unsigned n = 0; n < phi->get_num_incoming(); n++)
      incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
    // create phi for transposed values
    builder.set_insert_point(phi);
    ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
    for(unsigned n = 0; n < phi->get_num_incoming(); n++)
      result->add_incoming(incs[n], phi->get_incoming_block(n));
    return result;
  }
  else if(auto i = dynamic_cast<ir::instruction*>(value)){
    ir::basic_block* block = i->get_parent();
    auto it = std::find(block->begin(), block->end(), i);
    it++;
    builder.set_insert_point(it);
    ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
    trans->set_operand(0, i);
    return trans;
  }
  return nullptr;
}

bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
  auto trans = dynamic_cast<ir::trans_inst*>(value);
  if(!trans)
    return false;
  auto users = trans->get_users();
  auto ops = trans->ops();
  if(users.size() > 1 || ops.size() > 1)
    return false;
  ir::value* op = *ops.begin();
  // trans(phi) -> phi(trans(), trans()...)
  auto* phi = dynamic_cast<ir::phi_node*>(op);
  if(!phi)
    return false;
  ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
  if(!new_phi)
    return false;
  trans->replace_all_uses_with(new_phi);

  return true;
}

bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  // dot(a, b, c) + d -> dot(a, b, c + d)
  // d + dot(a, b, c) -> dot(a, b, c + d)
  auto add = dynamic_cast<ir::binary_operator*>(value);
  if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
    ir::value *lhs = add->get_operand(0);
    ir::value *rhs = add->get_operand(1);
    ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
    ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
    if(!lhs_dot && !rhs_dot)
      return false;
    ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
    ir::value *other = (dot == lhs) ? rhs : lhs;
    ir::value *acc = dot->get_operand(2);
    ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
    ir::constant *_0 = nullptr;
    if(splat)
      _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
    if(!_0)
      return false;
    if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
      if (fp_0->get_value() != 0.0)
        return false;
    if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
      if (int_0->get_value() != 0)
        return false;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    builder.set_insert_point(add);
    ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
    add->replace_all_uses_with(new_dot);
    return true;
  }
  return false;
}

//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
//  auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
//  if(cfs) {
//    ir::value *arg = cfs->get_operand(0);
//    ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
//    if(!cts)
//      return false;
//    cfs->replace_all_uses_with(cts->get_operand(0));
//    return true;
//  }
//}

bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
  auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
  if(!copy_to_shared)
    return false;
  ir::value *arg = copy_to_shared->get_operand(0);
  ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
  if(!ld)
    return false;
  builder.set_insert_point(copy_to_shared);
  ir::value *ptr = ld->get_pointer_operand();
  ir::value *msk = ld->get_mask_operand();
  ir::value *val = ld->get_false_value_operand();
  analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
  int nts = layout->nts(layout->get_order()[0]);
  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
  if(nts*dtsize >= 4){
    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
    copy_to_shared->replace_all_uses_with(new_load);
    return true;
  }
  return false;
}

bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  auto x = dynamic_cast<ir::reduce_inst*>(value);
  if(!x)
    return false;
  ir::value *arg = x->get_operand(0);
  auto shapes = arg->get_type()->get_block_shapes();
  if(shapes[x->get_axis()] == 1){
    builder.set_insert_point(x);
    ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
    x->replace_all_uses_with(new_red);
    return true;
  }
  return false;
}

bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
  auto binop = dynamic_cast<ir::binary_operator*>(value);
  if(binop && binop->get_op() == ir::binary_op_t::Mul) {
    ir::value *lhs = binop->get_operand(0);
    ir::value *rhs = binop->get_operand(1);
    ir::constant_int *_1_lhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_lhs = cst;
    }
    ir::constant_int *_1_rhs = nullptr;
    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
      auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
      if(cst && cst->get_value() == 1)
        _1_rhs = cst;
    }
    if(_1_lhs){
      binop->replace_all_uses_with(rhs);
      return true;
    }
    else if(_1_rhs){
      binop->replace_all_uses_with(lhs);
      return true;
    }
  }
  return false;
}


bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
  auto x = dynamic_cast<ir::getelementptr_inst*>(value);
  if(!x)
    return false;
  auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
  if(!y)
    return false;
  auto idx = *y->idx_begin();
  auto z = dynamic_cast<ir::binary_operator*>(idx);
  if(!z)
    return false;
  bool is_sub = z->get_op() == ir::binary_op_t::Sub;
  auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
  bool is_lhs_0 = lhs && (lhs->get_value()==0);
  bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
  if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
    x->replace_all_uses_with(y->get_pointer_operand());
    return true;
  }
  return false;
}

bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
  auto select = dynamic_cast<ir::select_inst*>(value);
  if(!select)
    return false;
  auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
  if(!if_value)
    return false;
  if(select->get_pred_op() != if_value->get_mask_operand())
    return false;
  builder.set_insert_point(select);
  ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
                                                   if_value->get_mask_operand(),
                                                   select->get_else_value_op(),
                                                   if_value->get_cache_modifier(),
                                                   if_value->get_eviction_policy(),
                                                   if_value->get_is_volatile());
  select->replace_all_uses_with(new_load);
  return true;
}

bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
  auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
  if(!cvt)
    return false;
  ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
  if(!op)
    return false;
  // // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
  // if(op->get_id() == ir::INST_BINOP){
  //   for(size_t i = 0; i < op->get_num_operands(); i++){
  //     ir::value* arg_i = op->get_operand(i);
  //     builder.set_insert_point(op);
  //     // create new layout transform
  //     ir::instruction* new_arg_i = cvt->clone();
  //     layouts_->copy(new_arg_i, op);
  //     builder.insert(new_arg_i);
  //     // set the right args
  //     new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
  //     op->replace_uses_of_with(arg_i, new_arg_i);
  //   }
  //   cvt->replace_all_uses_with(op);
  //   return true;
  // }
  auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
  if(!cvt_op)
    return false;
  // convert1(convert2(x)) if convert1 is the inverse of convert2,
  // i.e. if x and the outer conversion share the same layout
  ir::value* op_op = cvt_op->get_operand(0);
  if(layouts_->has(cvt) && layouts_->has(op_op) &&
     layouts_->get(cvt) && layouts_->get(op_op) &&
     layouts_->get(cvt) == layouts_->get(op_op)){
    cvt->replace_all_uses_with(op_op);
    return true;
  }
  return false;
}

void peephole::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // keep track of whether any modification was made
  std::set<ir::value*> seen;
  size_t n_seen;

  // rewrite dots first
  do{
    n_seen = seen.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* i: block->get_inst_list()){
      if(seen.find(i) != seen.end())
        continue;
      bool was_modified = rewrite_dot(i, builder);
      if(was_modified){
        seen.insert(i);
      }
    }
  }while(seen.size() != n_seen);

  // rewrite other ops
  seen.clear();
  do{
    n_seen = seen.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* i: block->get_inst_list()){
      if(seen.find(i) != seen.end())
        continue;
      bool was_modified = false;
      was_modified = was_modified || rewrite_mult(i, builder);
      // was_modified = was_modified || rewrite_cts_cfs(i, builder);
      // was_modified = was_modified || rewrite_trans_phi(i, builder);
      was_modified = was_modified || rewrite_unit_red(i, builder);
      was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
      // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
      // was_modified = was_modified || rewrite_select_masked_load(i, builder);
      was_modified = was_modified || rewrite_cvt_layout(i, builder);
      if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
        was_modified = was_modified || rewrite_load_to_shared(i, builder);
      if(was_modified)
        seen.insert(i);
    }
  }while(seen.size() != n_seen);
}

}
}
}
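
The driver loop above applies rewrite rules until a full sweep changes nothing. A standalone toy (not part of the original file) with a single rule, x*1 -> x, mirroring rewrite_mult:

#include <cstdio>

struct Term { char op; Term* lhs; Term* rhs; int value; };  // op == 0 means leaf

// x * 1 -> x and 1 * x -> x (cf. rewrite_mult above)
Term* rewrite_mult(Term* t, bool& changed) {
  if (t->op == '*') {
    if (t->rhs->op == 0 && t->rhs->value == 1) { changed = true; return t->lhs; }
    if (t->lhs->op == 0 && t->lhs->value == 1) { changed = true; return t->rhs; }
  }
  return t;
}

int main() {
  Term x{0, nullptr, nullptr, 7};
  Term one{0, nullptr, nullptr, 1};
  Term inner{'*', &x, &one, 0};     // x * 1
  Term outer{'*', &inner, &one, 0}; // (x * 1) * 1
  Term* root = &outer;
  bool changed;
  do {                              // apply rules until a sweep changes nothing
    changed = false;
    root = rewrite_mult(root, changed);
    if (root->op == '*') {
      bool c2 = false;
      root->lhs = rewrite_mult(root->lhs, c2);
      root->rhs = rewrite_mult(root->rhs, c2);
      changed = changed || c2;
    }
  } while (changed);
  std::printf("result leaf value: %d\n", root->value);  // 7
  return 0;
}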
@@ -1,330 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return;
|
||||
if(i->get_id()==ir::INST_PHI)
|
||||
return;
|
||||
ret.push_back(i);
|
||||
for(ir::user* u: i->get_users())
|
||||
recursive_deps(u, block, ret);
|
||||
}
|
||||
|
||||
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
|
||||
auto instr = dynamic_cast<ir::instruction*>(cond);
|
||||
for (auto op : instr->ops()) {
|
||||
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
|
||||
phis.insert(phi_op);
|
||||
return;
|
||||
}
|
||||
if (dynamic_cast<ir::instruction*>(op))
|
||||
get_induction_vars(op, phis);
|
||||
}
|
||||
}
|
||||
|
||||
/// assume incoming block is 1
|
||||
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
|
||||
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
|
||||
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
|
||||
throw std::runtime_error("Don't have that phi node\n");
|
||||
return prev_phi_vals.at(phi);
|
||||
}
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
|
||||
ir::value* v, size_t phi_idx){
|
||||
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i || i->get_parent() != block)
|
||||
return v;
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
|
||||
return phi->get_incoming_value(phi_idx);
|
||||
|
||||
std::vector<ir::value*> new_ops;
|
||||
for(ir::value* op: i->ops()){
|
||||
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
|
||||
}
|
||||
ir::instruction* ret = i->clone();
|
||||
for(size_t k = 0; k < new_ops.size(); k++)
|
||||
ret->set_operand(k, new_ops[k]);
|
||||
builder.insert(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// moving the prev phi vals to the next iteration
|
||||
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
|
||||
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
|
||||
std::map<ir::phi_node*, ir::value*> next_phi_vals;
|
||||
for (auto &[phi, val] : prev_phi_vals) {
|
||||
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
|
||||
}
|
||||
return next_phi_vals;
|
||||
}
|
||||
|
||||
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
|
||||
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
|
||||
for (auto& [phi, val] : load_ivs) {
|
||||
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
|
||||
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
|
||||
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
|
||||
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
|
||||
// cache next_k (to be used by next_mask)
|
||||
next_load_ivs[phi] = next_k;
|
||||
} else
|
||||
throw std::runtime_error("must be phi");
|
||||
}
|
||||
}
|
||||
|
||||
struct pipeline_info_t {
|
||||
ir::load_inst* load;
|
||||
ir::phi_node* ptr;
|
||||
ir::dot_inst* dot;
|
||||
|
||||
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
|
||||
: load(load), ptr(ptr), dot(dot) {}
|
||||
};
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
if (num_stages_ <= 1)
|
||||
return;
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
// - the pointer is a phi node that references a value
|
||||
// in its basic block (i.e., pointer induction variable)
|
||||
// - the load has only a single use in a dot instruction
|
||||
// As more use cases become apparent, this pass will be improved
|
||||
std::vector<pipeline_info_t> to_pipeline;
|
||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
|
||||
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
|
||||
auto users = load->get_users();
|
||||
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
|
||||
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
|
||||
&& users.size() == 1 && dot)
|
||||
to_pipeline.push_back({load, ptr, dot});
|
||||
}});
|
||||
// do the pipelining
|
||||
std::vector<ir::phi_node*> new_loads;
|
||||
ir::builder &builder = mod.get_builder();
|
||||
const int num_stages = num_stages_;
|
||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||
for(auto info: to_pipeline){
|
||||
ir::load_inst* load = info.load;
|
||||
ir::phi_node* ptr = info.ptr;
|
||||
ir::basic_block* block = load->get_parent();
|
||||
ir::basic_block* header = block->get_predecessors()[0];
|
||||
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
|
||||
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
|
||||
assert(block_br);
|
||||
assert(header_br);
|
||||
ir::type* ty = load->get_type();
|
||||
// multi-stage pipe
|
||||
if (has_copy_async_ && num_stages > 2) {
|
||||
ir::value* header_cond = header_br->get_cond();
|
||||
ir::value* block_cond = block_br->get_cond();
|
||||
// 1. collect induction variables
|
||||
std::set<ir::phi_node*> induction_vars;
|
||||
get_induction_vars(block_cond, induction_vars);
|
||||
|
||||
std::vector<ir::value*> first_ptrs(num_stages-1);
|
||||
std::vector<ir::value*> first_loads(num_stages-1);
|
||||
std::vector<ir::value*> first_masks(num_stages-1);
|
||||
std::vector<ir::value*> loop_conds(num_stages-1);
|
||||
|
||||
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
|
||||
// initialize prev_phi_vals
|
||||
// Add all phi nodes. The following DCE pass will delete dead ones.
|
||||
for (ir::instruction *instr : block->get_inst_list())
|
||||
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
|
||||
if (phi->get_incoming_block(1) == block)
|
||||
prev_phi_vals[phi] = phi->get_value_for_block(header);
|
||||
|
||||
builder.set_insert_point(header->get_inst_list().back());
|
||||
first_ptrs[0] = ptr->get_value_for_block(header);
|
||||
loop_conds[0] = header_cond;
|
||||
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
|
||||
ir::value* false_value = nullptr;
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
|
||||
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
|
||||
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
|
||||
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
|
||||
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
|
||||
ir::value* remat_false_value =
|
||||
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
|
||||
}
|
||||
|
||||
// create new phis for induction variables
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
std::map<ir::phi_node*, ir::value*> load_ivs;
|
||||
std::map<ir::phi_node*, ir::value*> next_load_ivs;
|
||||
for (auto& [iv, val] : prev_phi_vals) {
|
||||
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
|
||||
pn->add_incoming(prev_phi_vals[iv], header);
|
||||
load_ivs[iv] = pn;
|
||||
}
|
||||
// add incoming for phis & update next_load_ivs
|
||||
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
|
||||
|
||||
      // pre-fetch next iteration
      builder.set_insert_point(block->get_inst_list().back());
      ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
      ir::value* next_mask = builder.create_splat(
          rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
      if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
        ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
        // TODO: the false value may depend on some other phi nodes
        ir::value* remat_false_value =
            rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
        next_mask = builder.create_and(next_mask, remat_mask);
        false_value = remat_false_value;
      }
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());

      // phi node
      ptr->set_incoming_value(0, first_ptrs.back());
      builder.set_insert_point(block->get_first_non_phi());
      // nested phis for load
      std::vector<ir::phi_node*> new_load_phis(num_stages-1);
      for (auto& pn : new_load_phis)
        pn = builder.create_phi(ty, 2);
      for (int i = 0; i < num_stages-2; ++i) {
        new_load_phis[i]->add_incoming(first_loads[i], header);
        new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
      }
      new_load_phis.back()->add_incoming(first_loads.back(), header);
      new_load_phis.back()->add_incoming(next_load, block);
      load->replace_all_uses_with(new_load_phis.front());
      new_loads.push_back(new_load_phis.back());

      // record first_loads to reorder them
      preheader_loads.push_back({new_load_phis.front(), first_loads});
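      // The phis above form a shift register: new_load_phis[0] feeds the
      // current iteration's uses of the load, each iteration shifts
      // new_load_phis[i+1] into new_load_phis[i], and next_load refills the
      // last slot. For num_stages == 3, for example:
      //   cur = phi [first_loads[0], header], [nxt, block]
      //   nxt = phi [first_loads[1], header], [next_load, block]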
    } else {
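      // Classic double buffering (num_stages == 2): one copy of the data is
      // loaded in the header, the next iteration's copy at the bottom of the
      // loop body, and a two-way phi selects between them.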
      // pre-fetch first iteration
      builder.set_insert_point(header->get_inst_list().back());
      ir::value* first_ptr = ptr->get_value_for_block(header);
      ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
      ir::value* false_value;
      if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
        ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
        ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
        first_mask = builder.create_and(first_mask, remat_mask);
        false_value = remat_false_value;
      } else
        false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
      ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
      // pre-fetch next iteration
      builder.set_insert_point(block->get_inst_list().back());
      ir::value* next_ptr = ptr->get_value_for_block(block);
      ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
      if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
        ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
        ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
        next_mask = builder.create_and(next_mask, remat_mask);
        false_value = remat_false_value;
      }
      ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
      // phi node
      builder.set_insert_point(block->get_first_non_phi());
      ir::phi_node* new_load = builder.create_phi(ty, 2);
      new_load->add_incoming(first_load, header);
      new_load->add_incoming(next_load, block);
      load->replace_all_uses_with(new_load);
      new_loads.push_back(new_load);
    }
  }

  // try to reorder prefetched values from a0, a1, a2, ..., b0, b1, b2, ...
  // to a0, b0, a1, b1, ...
  if (!preheader_loads.empty()) {
    ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
    builder.set_insert_point(header->get_inst_list().back());
    for (int i = 1; i < num_stages-1; ++i) {
      for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
        ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
        ir::instruction* moved_load = original_load->clone();
        builder.insert(moved_load);
        original_load->replace_all_uses_with(moved_load);
      }
    }
  }

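  // Interleaving the prologue loads of different pipelined tensors lets them
  // issue back-to-back instead of being serialized per tensor; each cloned
  // load takes over all uses of the original at its new position.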
  // try to move each dot_inst after its loads
  // for better overlap of I/O and compute
  struct move_config_t {
    std::vector<ir::instruction*> insts;
    ir::load_inst* dst;
  };
  std::vector<move_config_t> to_move(to_pipeline.size());

  if (has_copy_async_) {
    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
      auto info = to_pipeline[idx];
      ir::load_inst* load = info.load;
      ir::phi_node* ptr = info.ptr;
      ir::dot_inst* dot = info.dot;
      ir::basic_block* bb = dot->get_parent();
      recursive_deps(dot, bb, to_move[idx].insts);
      to_move[idx].dst = load;
    }

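    // recursive_deps gathered each dot and its in-block dependency chain;
    // detach those instructions and re-insert them right after the load they
    // consume, so the copy is issued before the compute that uses it.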
    for (auto& move_config : to_move) {
      builder.set_insert_point_after(move_config.dst);
      for (ir::instruction* i : move_config.insts) {
        i->get_parent()->erase(i);
        builder.insert(i);
      }
    }
  }

}

}
}
}
@@ -1,133 +0,0 @@
#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/print.h"
#include <iostream>
#include <vector>
#include <algorithm>

namespace triton::codegen::transform {

/// collect the in-block defs of v, recursing through operands and stopping
/// at phi nodes
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
  ir::instruction *i = dynamic_cast<ir::instruction*>(v);
  if (!i || i->get_parent() != bb)
    return;
  if (i->get_id() == ir::INST_PHI)
    return;
  ret.push_back(i);
  for (ir::value *op : i->ops())
    recursive_defs(op, bb, ret);
}

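// Used below to move a masked load together with the address and mask
// computations it depends on as a single unit.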
void prefetch::run(ir::module &mod) {
  // 1. collect dots that can be prefetched
  std::vector<ir::dot_inst*> to_prefetch;
  ir::for_each_instruction(mod, [&](ir::instruction *i) {
    if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
      // for now, only prefetch when the dot uses tensor cores
      if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
            dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
            (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
             && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)))
        return;
      auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
      auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
      if (a && a->get_incoming_block(1) == a->get_parent() &&
          b && b->get_incoming_block(1) == b->get_parent())
        to_prefetch.push_back(dot);
    }
  });

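  // The check above admits fp16/bf16 operands unconditionally, and tf32
  // (fp32 with allow_tf32) or int8 x int8 operands only on NVIDIA sm >= 80.
  // Both operands must also be loop-carried phis (their second incoming
  // block is their own block), i.e. values updated by the loop.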
  assert(to_prefetch.size() <= 1 && "Don't know what to do with multiple dots");
  ir::builder &builder = mod.get_builder();
  // 2. do the prefetching
  for (ir::dot_inst* dot : to_prefetch) {
    auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
    auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
    assert(a && b);
    assert(a->get_incoming_block(0) == b->get_incoming_block(0));
    ir::basic_block *loop_header = a->get_incoming_block(0);
    ir::basic_block *loop_body = a->get_parent();

    // mark as prefetched
    dot->set_prefetched(true);

    // 1. in the loop header (first iteration)
    builder.set_insert_point(loop_header->get_inst_list().back());
    builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
    builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);

    // 2. at the end of the loop body (next iteration)
    builder.set_insert_point(loop_body->get_inst_list().back());
    builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
    builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);

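    // Double buffering for the dot operands: the first iteration's values
    // are prefetched in the loop header, and every body iteration prefetches
    // the values for the next one (the /*inc*/ flag distinguishes the
    // current buffer from the next).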
    prefetched_vals_.insert(a->get_incoming_value(0));
    prefetched_vals_.insert(b->get_incoming_value(0));
    // nested phis
    ir::value* next_a = a->get_incoming_value(1);
    while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
      prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
      next_a = next_a_phi->get_incoming_value(1);
    }
    prefetched_vals_.insert(next_a);

    ir::value* next_b = b->get_incoming_value(1);
    while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
      prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
      next_b = next_b_phi->get_incoming_value(1);
    }
    prefetched_vals_.insert(next_b);
  }

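  // With multi-stage pipelining an operand can be a chain of nested phis,
  // one per in-flight stage; the two loops above walk those chains so that
  // every rotated value is recorded as prefetched.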
  // move loads to the beginning of the loop
  if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
    for (ir::function *fn : mod.get_function_list())
      for (ir::basic_block *bb : fn->blocks()) {
        // only apply to loop body
        if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
          continue;
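        // (a block whose second predecessor is itself is a single-block loop
        // body: one entry edge from the header plus its own back-edge; other
        // shapes are skipped above)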
        // record loads (& dependencies) to move
        std::vector<ir::instruction*> loads;
        // record original inst order
        std::map<ir::instruction*, size_t> idx_map;
        size_t idx = 0;
        for (ir::instruction *inst : bb->get_inst_list()) {
          if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
            recursive_defs(i, bb, loads);
          idx_map[inst] = idx;
          idx++;
        }

        // remove duplicates & keep the original program order
        std::sort(loads.begin(), loads.end());
        loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
        std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
          return idx_map[a] < idx_map[b];
        });

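        // The first sort orders by pointer so std::unique can drop
        // duplicates; the second restores the original program order
        // recorded in idx_map before the loads are hoisted.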
        builder.set_insert_point(bb->get_first_non_phi());
        auto& inst_list = bb->get_inst_list();
        for (ir::instruction *i : loads) {
          auto it = std::find(inst_list.begin(), inst_list.end(), i);
          // make sure we don't invalidate the insert point
          // in case the instruction is already at the top
          if (it == builder.get_insert_point())
            continue;
          bb->erase(i);
          builder.insert(i);
        }
      }
  }
}
} // namespace triton::codegen::transform
@@ -1,51 +0,0 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"

namespace triton {
namespace codegen{
namespace transform{

void reorder::run(ir::module& mod){
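  // The body of this pass is fully commented out, so reorder::run is
  // currently a no-op; the sketch below re-created each masked load right
  // after the last of its operands was defined.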
  // ir::builder &builder = mod.get_builder();
  // std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;

  // for(ir::function *fn: mod.get_function_list())
  // for(ir::basic_block *block: fn->blocks())
  // for(ir::instruction* i: block->get_inst_list()){
  //   if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
  //     ir::value* _ptr = ld->get_pointer_operand();
  //     ir::value* _msk = ld->get_mask_operand();
  //     ir::value* _val = ld->get_false_value_operand();
  //     auto ptr = std::find(block->begin(), block->end(), _ptr);
  //     auto msk = std::find(block->begin(), block->end(), _msk);
  //     auto val = std::find(block->begin(), block->end(), _val);
  //     if(ptr == block->end() || msk == block->end() || val == block->end())
  //       continue;
  //     auto it = std::find(block->begin(), block->end(), i);
  //     int dist_ptr = std::distance(ptr, it);
  //     int dist_msk = std::distance(msk, it);
  //     int dist_val = std::distance(val, it);
  //     if(dist_ptr < dist_msk && dist_ptr < dist_val)
  //       builder.set_insert_point(++ptr);
  //     if(dist_msk < dist_ptr && dist_msk < dist_val)
  //       builder.set_insert_point(++msk);
  //     if(dist_val < dist_ptr && dist_val < dist_msk)
  //       builder.set_insert_point(++val);
  //     ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
  //     to_replace.push_back(std::make_pair(ld, new_ld));
  //   }
  // }

  // for(auto& x: to_replace)
  //   x.first->replace_all_uses_with(x.second);

}

}
}
}