[codegen] rough template for axis_info pass

This commit is contained in:
Philippe Tillet
2019-06-24 18:57:32 -07:00
parent 72867d17d4
commit edc31cabb0
9 changed files with 222 additions and 1 deletions

View File

@@ -30,7 +30,7 @@ void matmul(restrict read_only align(4) fp16 *A,
restrict read_only align(4) fp16 *B,
align(4) fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);

View File

@@ -0,0 +1,39 @@
#ifndef TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class value;
class module;
}
namespace codegen{
class axis_info {
private:
// helpers
bool is_first_axis_unit(ir::value *x);
// populate maps
bool populate_is_constant(ir::value *i);
unsigned populate_max_contiguous(ir::value *i);
unsigned populate_multiple_of(ir::value *i);
public:
void run(ir::module &mod);
private:
std::map<ir::value*, bool> is_constant_;
std::map<ir::value*, unsigned> max_contiguous_;
std::map<ir::value*, unsigned> multiple_of_;
};
}
}
#endif

View File

@@ -5,6 +5,7 @@
#include <map>
#include "value.h"
#include "constant.h"
#include <iostream>
namespace triton{
namespace ir{
@@ -21,6 +22,8 @@ class argument: public value{
public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
private:
function *parent_;
@@ -53,6 +56,10 @@ public:
return value_;
}
bool is_llvm_attr() const {
return kind_ != multiple_of;
}
private:
attribute_kind_t kind_;
unsigned value_;
@@ -89,6 +96,7 @@ public:
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
private:
module *parent_;

View File

@@ -122,6 +122,12 @@ public:
bool is_int_div_rem() const;
bool is_shift() const;
bool is_cast() const;
bool is_int_mult() const;
bool is_int_add_sub() const;
bool is_int_div() const;
bool is_int_rem() const;
bool is_shl() const;
bool is_shr() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }

View File

@@ -17,6 +17,7 @@
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/axis_info.h"
#include "triton/codegen/target.h"
#include "triton/codegen/vectorize.h"
#include <functional>
@@ -60,11 +61,13 @@ public:
optimize_dot(&tune),
optimize_cse(),
optimize_trans(),
axis_info(),
target_(target) { }
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
axis_info.run(module);
// ir::print(module, std::cout);
}
@@ -88,6 +91,7 @@ public:
codegen::optimize_dot optimize_dot;
codegen::optimize_cse optimize_cse;
codegen::optimize_trans optimize_trans;
codegen::axis_info axis_info;
codegen::target* target_;
};

129
lib/codegen/axis_info.cpp Normal file
View File

@@ -0,0 +1,129 @@
#include "triton/codegen/axis_info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
namespace triton {
namespace codegen{
template<class T>
inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
return map.insert(std::make_pair(i, value)).first->second;
}
bool axis_info::is_first_axis_unit(ir::value *x){
if(x->get_type()->is_tile_ty())
return x->get_type()->get_tile_shapes()[0]->get_value() == 1;
else
return true;
}
bool axis_info::populate_is_constant(ir::value *v) {
// helper for the cache
auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
// populate
if(v->get_type()->is_tile_ty()){
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
bool value = populate_is_constant(x->get_operand(0));
// check if broadcast (i.e., constant) along contiguous dimension
if(is_first_axis_unit(x->get_operand(0))
&& !is_first_axis_unit(x))
return cache(value);
}
// otherwise the tile is not constant in the contiguous dimension
return cache(false);
}
// scalars are always constant in the contiguous dimension
return cache(true);
}
unsigned axis_info::populate_max_contiguous(ir::value *v){
// helper for the cache
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
// populate
if(v->get_type()->is_tile_ty()){
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::get_global_range_inst*>(v))
return cache(shapes[0]->get_value());
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
bool lhs_has_cst = populate_is_constant(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
bool rhs_has_cst = populate_is_constant(rhs);
if(x->is_int_add_sub()){
if(lhs_has_cst)
return cache(rhs_max_contiguous);
if(rhs_has_cst)
return cache(lhs_max_contiguous);
}
}
}
return cache(1);
}
unsigned axis_info::populate_multiple_of(ir::value *v){
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of)
return cache(attr.get_value());
}
}
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
int lhs = populate_multiple_of(x->get_operand(0));
int rhs = populate_multiple_of(x->get_operand(1));
if(x->is_int_mult())
return cache(lhs * rhs);
if(x->is_int_add_sub())
return cache(std::min(lhs, rhs));
if(x->is_int_div())
return cache(std::max(lhs / rhs, 1));
if(x->is_int_rem())
return cache(std::max(lhs % rhs, 1));
if(x->is_shl())
return cache(lhs << rhs);
if(x->is_shr())
return cache(std::max(lhs >> rhs, 1));
}
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
return cache(populate_multiple_of(x->get_operand(0)));
}
return cache(1);
}
void axis_info::run(ir::module &mod) {
// populate constant
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_is_constant(i);
}
// populate multiple_of
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_multiple_of(i);
}
// populate maximum contiguous
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
}
}
}
}

View File

@@ -1125,6 +1125,7 @@ void selection::run(ir::module &src, Module &dst) {
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
if(attr.is_llvm_attr())
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
}
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);

View File

@@ -16,6 +16,15 @@ argument *argument::create(type *ty, const std::string &name,
return new argument(ty, name, parent, arg_no);
}
function* argument::get_parent() const {
return parent_;
}
unsigned argument::get_arg_no() const {
return arg_no_;
}
/* function */
function::function(function_type *ty, linkage_types_t linkage,
const std::string &name, module *parent)

View File

@@ -109,6 +109,31 @@ std::string binary_operator::repr_impl() const {
}
}
bool binary_operator::is_int_div() const {
return op_ == llop::UDiv || op_ == llop::SDiv;
}
bool binary_operator::is_int_rem() const {
return op_ == llop::URem || op_ == llop::SRem;
}
bool binary_operator::is_shl() const {
return op_ == llop::Shl;
}
bool binary_operator::is_shr() const {
return op_ == llop::LShr || op_ == llop::AShr;
}
bool binary_operator::is_int_mult() const {
return op_ == llop::Mul;
}
bool binary_operator::is_int_add_sub() const {
return op_ == llop::Add || llop::Sub;
}
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, 1, name, next), op_(op){
set_operand(0, lhs);