[codegen] rough template for axis_info pass
This commit is contained in:
@@ -30,7 +30,7 @@ void matmul(restrict read_only align(4) fp16 *A,
|
|||||||
restrict read_only align(4) fp16 *B,
|
restrict read_only align(4) fp16 *B,
|
||||||
align(4) fp32 *C,
|
align(4) fp32 *C,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 lda, int32 ldb, int32 ldc,
|
multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc,
|
||||||
int32 *locks, int32 grid0, int32 grid1) {
|
int32 *locks, int32 grid0, int32 grid1) {
|
||||||
int32 rxa[TM] = get_global_range[TM](0);
|
int32 rxa[TM] = get_global_range[TM](0);
|
||||||
int32 ryb[TN] = get_global_range[TN](1);
|
int32 ryb[TN] = get_global_range[TN](1);
|
||||||
|
39
include/triton/codegen/axis_info.h
Normal file
39
include/triton/codegen/axis_info.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#ifndef TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
|
||||||
|
#define TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
namespace ir {
|
||||||
|
class value;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace codegen{
|
||||||
|
|
||||||
|
class axis_info {
|
||||||
|
private:
|
||||||
|
// helpers
|
||||||
|
bool is_first_axis_unit(ir::value *x);
|
||||||
|
|
||||||
|
// populate maps
|
||||||
|
bool populate_is_constant(ir::value *i);
|
||||||
|
unsigned populate_max_contiguous(ir::value *i);
|
||||||
|
unsigned populate_multiple_of(ir::value *i);
|
||||||
|
|
||||||
|
public:
|
||||||
|
void run(ir::module &mod);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<ir::value*, bool> is_constant_;
|
||||||
|
std::map<ir::value*, unsigned> max_contiguous_;
|
||||||
|
std::map<ir::value*, unsigned> multiple_of_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -5,6 +5,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include "value.h"
|
#include "value.h"
|
||||||
#include "constant.h"
|
#include "constant.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace ir{
|
namespace ir{
|
||||||
@@ -21,6 +22,8 @@ class argument: public value{
|
|||||||
public:
|
public:
|
||||||
static argument* create(type *ty, const std::string &name,
|
static argument* create(type *ty, const std::string &name,
|
||||||
function *parent = nullptr, unsigned arg_no = 0);
|
function *parent = nullptr, unsigned arg_no = 0);
|
||||||
|
function* get_parent() const;
|
||||||
|
unsigned get_arg_no() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
function *parent_;
|
function *parent_;
|
||||||
@@ -53,6 +56,10 @@ public:
|
|||||||
return value_;
|
return value_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_llvm_attr() const {
|
||||||
|
return kind_ != multiple_of;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
attribute_kind_t kind_;
|
attribute_kind_t kind_;
|
||||||
unsigned value_;
|
unsigned value_;
|
||||||
@@ -89,6 +96,7 @@ public:
|
|||||||
// attributes
|
// attributes
|
||||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||||
const attr_map_t &attrs() { return attrs_; }
|
const attr_map_t &attrs() { return attrs_; }
|
||||||
|
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
module *parent_;
|
module *parent_;
|
||||||
|
@@ -122,6 +122,12 @@ public:
|
|||||||
bool is_int_div_rem() const;
|
bool is_int_div_rem() const;
|
||||||
bool is_shift() const;
|
bool is_shift() const;
|
||||||
bool is_cast() const;
|
bool is_cast() const;
|
||||||
|
bool is_int_mult() const;
|
||||||
|
bool is_int_add_sub() const;
|
||||||
|
bool is_int_div() const;
|
||||||
|
bool is_int_rem() const;
|
||||||
|
bool is_shl() const;
|
||||||
|
bool is_shr() const;
|
||||||
|
|
||||||
// Wraps
|
// Wraps
|
||||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||||
|
@@ -17,6 +17,7 @@
|
|||||||
#include "triton/codegen/shmem_liveness.h"
|
#include "triton/codegen/shmem_liveness.h"
|
||||||
#include "triton/codegen/shmem_info.h"
|
#include "triton/codegen/shmem_info.h"
|
||||||
#include "triton/codegen/shmem_barriers.h"
|
#include "triton/codegen/shmem_barriers.h"
|
||||||
|
#include "triton/codegen/axis_info.h"
|
||||||
#include "triton/codegen/target.h"
|
#include "triton/codegen/target.h"
|
||||||
#include "triton/codegen/vectorize.h"
|
#include "triton/codegen/vectorize.h"
|
||||||
#include <functional>
|
#include <functional>
|
||||||
@@ -60,11 +61,13 @@ public:
|
|||||||
optimize_dot(&tune),
|
optimize_dot(&tune),
|
||||||
optimize_cse(),
|
optimize_cse(),
|
||||||
optimize_trans(),
|
optimize_trans(),
|
||||||
|
axis_info(),
|
||||||
target_(target) { }
|
target_(target) { }
|
||||||
|
|
||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
|
axis_info.run(module);
|
||||||
// ir::print(module, std::cout);
|
// ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,6 +91,7 @@ public:
|
|||||||
codegen::optimize_dot optimize_dot;
|
codegen::optimize_dot optimize_dot;
|
||||||
codegen::optimize_cse optimize_cse;
|
codegen::optimize_cse optimize_cse;
|
||||||
codegen::optimize_trans optimize_trans;
|
codegen::optimize_trans optimize_trans;
|
||||||
|
codegen::axis_info axis_info;
|
||||||
codegen::target* target_;
|
codegen::target* target_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
129
lib/codegen/axis_info.cpp
Normal file
129
lib/codegen/axis_info.cpp
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
#include "triton/codegen/axis_info.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/function.h"
|
||||||
|
#include "triton/ir/basic_block.h"
|
||||||
|
#include "triton/ir/instructions.h"
|
||||||
|
#include "triton/ir/type.h"
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
namespace codegen{
|
||||||
|
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||||
|
return map.insert(std::make_pair(i, value)).first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool axis_info::is_first_axis_unit(ir::value *x){
|
||||||
|
if(x->get_type()->is_tile_ty())
|
||||||
|
return x->get_type()->get_tile_shapes()[0]->get_value() == 1;
|
||||||
|
else
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool axis_info::populate_is_constant(ir::value *v) {
|
||||||
|
// helper for the cache
|
||||||
|
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
|
||||||
|
// populate
|
||||||
|
if(v->get_type()->is_tile_ty()){
|
||||||
|
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||||
|
bool value = populate_is_constant(x->get_operand(0));
|
||||||
|
// check if broadcast (i.e., constant) along contiguous dimension
|
||||||
|
if(is_first_axis_unit(x->get_operand(0))
|
||||||
|
&& !is_first_axis_unit(x))
|
||||||
|
return cache(value);
|
||||||
|
}
|
||||||
|
// otherwise the tile is not constant in the contiguous dimension
|
||||||
|
return cache(false);
|
||||||
|
}
|
||||||
|
// scalars are always constant in the contiguous dimension
|
||||||
|
return cache(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned axis_info::populate_max_contiguous(ir::value *v){
|
||||||
|
// helper for the cache
|
||||||
|
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
|
||||||
|
// populate
|
||||||
|
if(v->get_type()->is_tile_ty()){
|
||||||
|
auto shapes = v->get_type()->get_tile_shapes();
|
||||||
|
if(dynamic_cast<ir::get_global_range_inst*>(v))
|
||||||
|
return cache(shapes[0]->get_value());
|
||||||
|
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||||
|
ir::value* lhs = x->get_operand(0);
|
||||||
|
ir::value* rhs = x->get_operand(1);
|
||||||
|
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||||
|
bool lhs_has_cst = populate_is_constant(lhs);
|
||||||
|
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||||
|
bool rhs_has_cst = populate_is_constant(rhs);
|
||||||
|
if(x->is_int_add_sub()){
|
||||||
|
if(lhs_has_cst)
|
||||||
|
return cache(rhs_max_contiguous);
|
||||||
|
if(rhs_has_cst)
|
||||||
|
return cache(lhs_max_contiguous);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cache(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned axis_info::populate_multiple_of(ir::value *v){
|
||||||
|
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
|
||||||
|
|
||||||
|
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||||
|
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||||
|
for(auto attr: attributes){
|
||||||
|
if(attr.get_kind() == ir::multiple_of)
|
||||||
|
return cache(attr.get_value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||||
|
int lhs = populate_multiple_of(x->get_operand(0));
|
||||||
|
int rhs = populate_multiple_of(x->get_operand(1));
|
||||||
|
if(x->is_int_mult())
|
||||||
|
return cache(lhs * rhs);
|
||||||
|
if(x->is_int_add_sub())
|
||||||
|
return cache(std::min(lhs, rhs));
|
||||||
|
if(x->is_int_div())
|
||||||
|
return cache(std::max(lhs / rhs, 1));
|
||||||
|
if(x->is_int_rem())
|
||||||
|
return cache(std::max(lhs % rhs, 1));
|
||||||
|
if(x->is_shl())
|
||||||
|
return cache(lhs << rhs);
|
||||||
|
if(x->is_shr())
|
||||||
|
return cache(std::max(lhs >> rhs, 1));
|
||||||
|
}
|
||||||
|
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||||
|
return cache(populate_multiple_of(x->get_operand(0)));
|
||||||
|
}
|
||||||
|
return cache(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void axis_info::run(ir::module &mod) {
|
||||||
|
// populate constant
|
||||||
|
for(ir::function *fn: mod.get_function_list())
|
||||||
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
for(ir::instruction *i: block->get_inst_list()){
|
||||||
|
populate_is_constant(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// populate multiple_of
|
||||||
|
for(ir::function *fn: mod.get_function_list())
|
||||||
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
for(ir::instruction *i: block->get_inst_list()){
|
||||||
|
populate_multiple_of(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// populate maximum contiguous
|
||||||
|
for(ir::function *fn: mod.get_function_list())
|
||||||
|
for(ir::basic_block *block: fn->blocks())
|
||||||
|
for(ir::instruction *i: block->get_inst_list()){
|
||||||
|
populate_max_contiguous(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
@@ -1125,6 +1125,7 @@ void selection::run(ir::module &src, Module &dst) {
|
|||||||
for(auto attr_pair: fn->attrs()){
|
for(auto attr_pair: fn->attrs()){
|
||||||
unsigned id = attr_pair.first;
|
unsigned id = attr_pair.first;
|
||||||
for(ir::attribute attr: attr_pair.second)
|
for(ir::attribute attr: attr_pair.second)
|
||||||
|
if(attr.is_llvm_attr())
|
||||||
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
||||||
}
|
}
|
||||||
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
||||||
|
@@ -16,6 +16,15 @@ argument *argument::create(type *ty, const std::string &name,
|
|||||||
return new argument(ty, name, parent, arg_no);
|
return new argument(ty, name, parent, arg_no);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function* argument::get_parent() const {
|
||||||
|
return parent_;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned argument::get_arg_no() const {
|
||||||
|
return arg_no_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* function */
|
/* function */
|
||||||
function::function(function_type *ty, linkage_types_t linkage,
|
function::function(function_type *ty, linkage_types_t linkage,
|
||||||
const std::string &name, module *parent)
|
const std::string &name, module *parent)
|
||||||
|
@@ -109,6 +109,31 @@ std::string binary_operator::repr_impl() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_int_div() const {
|
||||||
|
return op_ == llop::UDiv || op_ == llop::SDiv;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_int_rem() const {
|
||||||
|
return op_ == llop::URem || op_ == llop::SRem;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_shl() const {
|
||||||
|
return op_ == llop::Shl;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_shr() const {
|
||||||
|
return op_ == llop::LShr || op_ == llop::AShr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_int_mult() const {
|
||||||
|
return op_ == llop::Mul;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool binary_operator::is_int_add_sub() const {
|
||||||
|
return op_ == llop::Add || llop::Sub;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||||
: instruction(ty, 2, 1, name, next), op_(op){
|
: instruction(ty, 2, 1, name, next), op_(op){
|
||||||
set_operand(0, lhs);
|
set_operand(0, lhs);
|
||||||
|
Reference in New Issue
Block a user