[codegen] rough template for axis_info pass
This commit is contained in:
@@ -30,7 +30,7 @@ void matmul(restrict read_only align(4) fp16 *A,
|
||||
restrict read_only align(4) fp16 *B,
|
||||
align(4) fp32 *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 lda, int32 ldb, int32 ldc,
|
||||
multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc,
|
||||
int32 *locks, int32 grid0, int32 grid1) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
|
39
include/triton/codegen/axis_info.h
Normal file
39
include/triton/codegen/axis_info.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class value;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class axis_info {
|
||||
private:
|
||||
// helpers
|
||||
bool is_first_axis_unit(ir::value *x);
|
||||
|
||||
// populate maps
|
||||
bool populate_is_constant(ir::value *i);
|
||||
unsigned populate_max_contiguous(ir::value *i);
|
||||
unsigned populate_multiple_of(ir::value *i);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, bool> is_constant_;
|
||||
std::map<ir::value*, unsigned> max_contiguous_;
|
||||
std::map<ir::value*, unsigned> multiple_of_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -5,6 +5,7 @@
|
||||
#include <map>
|
||||
#include "value.h"
|
||||
#include "constant.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -21,6 +22,8 @@ class argument: public value{
|
||||
public:
|
||||
static argument* create(type *ty, const std::string &name,
|
||||
function *parent = nullptr, unsigned arg_no = 0);
|
||||
function* get_parent() const;
|
||||
unsigned get_arg_no() const;
|
||||
|
||||
private:
|
||||
function *parent_;
|
||||
@@ -53,6 +56,10 @@ public:
|
||||
return value_;
|
||||
}
|
||||
|
||||
bool is_llvm_attr() const {
|
||||
return kind_ != multiple_of;
|
||||
}
|
||||
|
||||
private:
|
||||
attribute_kind_t kind_;
|
||||
unsigned value_;
|
||||
@@ -89,6 +96,7 @@ public:
|
||||
// attributes
|
||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
// Returns a copy of the attribute set recorded for `arg`, or an empty set
// when nothing was recorded for it.
// Attributes are looked up at slot `arg_no + 1`.
// NOTE(review): presumably slot 0 is reserved for the return value, as with
// LLVM attribute indices - confirm against the callers of add_attr.
std::set<attribute> get_attributes(argument* arg) {
  // Use find() rather than operator[] so that a pure query never inserts an
  // empty entry into attrs_ (which attrs() would then expose to iteration).
  auto it = attrs_.find(arg->get_arg_no() + 1);
  if(it == attrs_.end())
    return std::set<attribute>();
  return it->second;
}
|
||||
|
||||
private:
|
||||
module *parent_;
|
||||
|
@@ -122,6 +122,12 @@ public:
|
||||
bool is_int_div_rem() const;
|
||||
bool is_shift() const;
|
||||
bool is_cast() const;
|
||||
bool is_int_mult() const;
|
||||
bool is_int_add_sub() const;
|
||||
bool is_int_div() const;
|
||||
bool is_int_rem() const;
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
|
@@ -17,6 +17,7 @@
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/shmem_barriers.h"
|
||||
#include "triton/codegen/axis_info.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include <functional>
|
||||
@@ -60,11 +61,13 @@ public:
|
||||
optimize_dot(&tune),
|
||||
optimize_cse(),
|
||||
optimize_trans(),
|
||||
axis_info(),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
axis_info.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
@@ -88,6 +91,7 @@ public:
|
||||
codegen::optimize_dot optimize_dot;
|
||||
codegen::optimize_cse optimize_cse;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::axis_info axis_info;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
|
129
lib/codegen/axis_info.cpp
Normal file
129
lib/codegen/axis_info.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
#include "triton/codegen/axis_info.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
|
||||
template<class T>
|
||||
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
return map.insert(std::make_pair(i, value)).first->second;
|
||||
}
|
||||
|
||||
|
||||
// Returns true when `x` has extent 1 along its first axis.
// Non-tile values are treated as having a unit first axis.
bool axis_info::is_first_axis_unit(ir::value *x){
  auto *ty = x->get_type();
  if(!ty->is_tile_ty())
    return true;
  return ty->get_tile_shapes()[0]->get_value() == 1;
}
|
||||
|
||||
// Determines whether `v` is constant along the contiguous (first) dimension,
// stores the answer in is_constant_ and returns the stored value.
//  - non-tile values: always constant
//  - retile instructions that expand a unit first axis into a non-unit one
//    (a broadcast): as constant as the retiled operand
//  - every other tile: not constant
bool axis_info::populate_is_constant(ir::value *v) {
  auto remember = [this, v](bool flag){ return add_to_cache(v, flag, is_constant_); };

  // scalars are always constant in the contiguous dimension
  if(!v->get_type()->is_tile_ty())
    return remember(true);

  auto *retile = dynamic_cast<ir::retile_inst*>(v);
  if(retile == nullptr)
    return remember(false);

  // the operand is analysed unconditionally so that it gets cached too
  ir::value *source = retile->get_operand(0);
  const bool source_is_constant = populate_is_constant(source);

  // broadcast along the contiguous dimension: unit first axis in, non-unit out
  const bool broadcasts_first_axis = is_first_axis_unit(source)
                                     && !is_first_axis_unit(retile);
  if(broadcasts_first_axis)
    return remember(source_is_constant);

  // otherwise the tile is not constant in the contiguous dimension
  return remember(false);
}
|
||||
|
||||
// Computes the maximum contiguity of `v`, stores it in max_contiguous_ and
// returns the stored value.
//  - get_global_range tiles: the size of the first tile dimension
//  - binary operators for which is_int_add_sub() holds and one operand is
//    constant along the contiguous dimension: the contiguity of the other
//    operand
//  - everything else (including non-tile values): 1
unsigned axis_info::populate_max_contiguous(ir::value *v){
  auto remember = [this, v](unsigned n){ return add_to_cache(v, n, max_contiguous_); };

  if(!v->get_type()->is_tile_ty())
    return remember(1);

  auto shapes = v->get_type()->get_tile_shapes();
  if(dynamic_cast<ir::get_global_range_inst*>(v))
    return remember(shapes[0]->get_value());

  auto *binop = dynamic_cast<ir::binary_operator*>(v);
  if(binop == nullptr)
    return remember(1);

  // analyse both operands (this also caches their results)
  ir::value* left = binop->get_operand(0);
  ir::value* right = binop->get_operand(1);
  unsigned left_contiguous = populate_max_contiguous(left);
  bool left_is_constant = populate_is_constant(left);
  unsigned right_contiguous = populate_max_contiguous(right);
  bool right_is_constant = populate_is_constant(right);

  if(binop->is_int_add_sub()){
    // adding/subtracting a constant keeps the other operand's contiguity
    if(left_is_constant)
      return remember(right_contiguous);
    if(right_is_constant)
      return remember(left_contiguous);
  }
  return remember(1);
}
|
||||
|
||||
// Computes a known divisor of `v`, stores it in multiple_of_ and returns the
// stored value.
//  - arguments: the value of the `multiple_of` attribute attached to the
//    argument in its parent function, when there is one
//  - binary operators: derived from the divisors of both operands (see below)
//  - retile instructions: forwarded from the retiled operand
//  - anything else: 1, i.e. no divisibility information
unsigned axis_info::populate_multiple_of(ir::value *v){
  // helper for the cache: results must land in multiple_of_; storing them in
  // max_contiguous_ would leave multiple_of_ empty and corrupt the contiguity
  // map, since add_to_cache never overwrites an existing entry
  auto cache = [this,v](unsigned value){ return add_to_cache(v, value, multiple_of_); };

  // greatest common divisor (Euclid); gcd(a, 0) == a
  auto gcd = [](int a, int b){
    while(b != 0){
      int r = a % b;
      a = b;
      b = r;
    }
    return a;
  };

  if(auto *x = dynamic_cast<ir::argument*>(v)){
    std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
    for(auto attr: attributes){
      if(attr.get_kind() == ir::multiple_of)
        return cache(attr.get_value());
    }
  }
  if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
    // lhs / rhs are the known divisors of the operands, not their values
    int lhs = populate_multiple_of(x->get_operand(0));
    int rhs = populate_multiple_of(x->get_operand(1));
    if(x->is_int_mult())
      return cache(lhs * rhs);
    if(x->is_int_add_sub())
      // (a*lhs) +/- (b*rhs) is only guaranteed divisible by gcd(lhs, rhs);
      // min(lhs, rhs) would over-claim (e.g. 4 + 6 is not a multiple of 4)
      return cache(std::max(gcd(lhs, rhs), 1));
    // NOTE(review): the rules below combine operand divisors as if they were
    // the operand values themselves; they look heuristic - confirm before
    // relying on them for alignment decisions.
    if(x->is_int_div())
      return cache(std::max(lhs / rhs, 1));
    if(x->is_int_rem())
      return cache(std::max(lhs % rhs, 1));
    if(x->is_shl())
      return cache(lhs << rhs);
    if(x->is_shr())
      return cache(std::max(lhs >> rhs, 1));
  }
  if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
    return cache(populate_multiple_of(x->get_operand(0)));
  }
  return cache(1);
}
|
||||
|
||||
|
||||
|
||||
// Runs the three analyses over every instruction of every basic block of
// every function in `mod`. Each analysis is a separate full sweep, in this
// order: constancy, multiple-of, maximum contiguity.
void axis_info::run(ir::module &mod) {
  // sweep 1: constancy along the contiguous dimension
  for(ir::function *func: mod.get_function_list()){
    for(ir::basic_block *bb: func->blocks()){
      for(ir::instruction *inst: bb->get_inst_list())
        populate_is_constant(inst);
    }
  }

  // sweep 2: known divisors
  for(ir::function *func: mod.get_function_list()){
    for(ir::basic_block *bb: func->blocks()){
      for(ir::instruction *inst: bb->get_inst_list())
        populate_multiple_of(inst);
    }
  }

  // sweep 3: maximum contiguity
  for(ir::function *func: mod.get_function_list()){
    for(ir::basic_block *bb: func->blocks()){
      for(ir::instruction *inst: bb->get_inst_list())
        populate_max_contiguous(inst);
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1125,6 +1125,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
for(ir::attribute attr: attr_pair.second)
|
||||
if(attr.is_llvm_attr())
|
||||
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
||||
}
|
||||
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
||||
|
@@ -16,6 +16,15 @@ argument *argument::create(type *ty, const std::string &name,
|
||||
return new argument(ty, name, parent, arg_no);
|
||||
}
|
||||
|
||||
// Accessor for the function stored as this argument's parent.
// NOTE(review): argument::create defaults `parent` to nullptr, so callers
// presumably need to handle a null result - confirm.
function* argument::get_parent() const { return parent_; }
|
||||
|
||||
unsigned argument::get_arg_no() const {
|
||||
return arg_no_;
|
||||
}
|
||||
|
||||
|
||||
/* function */
|
||||
function::function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *parent)
|
||||
|
@@ -109,6 +109,31 @@ std::string binary_operator::repr_impl() const {
|
||||
}
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_div() const {
|
||||
return op_ == llop::UDiv || op_ == llop::SDiv;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_rem() const {
|
||||
return op_ == llop::URem || op_ == llop::SRem;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shl() const {
|
||||
return op_ == llop::Shl;
|
||||
}
|
||||
|
||||
bool binary_operator::is_shr() const {
|
||||
return op_ == llop::LShr || op_ == llop::AShr;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_mult() const {
|
||||
return op_ == llop::Mul;
|
||||
}
|
||||
|
||||
bool binary_operator::is_int_add_sub() const {
|
||||
return op_ == llop::Add || llop::Sub;
|
||||
}
|
||||
|
||||
|
||||
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 2, 1, name, next), op_(op){
|
||||
set_operand(0, lhs);
|
||||
|
Reference in New Issue
Block a user