[auto-coalesce] more bugfixes

This commit is contained in:
Philippe Tillet
2019-09-16 13:28:23 -04:00
parent 8d37a55a21
commit e184bad9a1
8 changed files with 128 additions and 93 deletions

View File

@@ -19,14 +19,11 @@ namespace ir{
namespace codegen{
namespace transform{
class coalesce;
}
namespace analysis{
class axes;
class layout;
class align;
class tiles {
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
@@ -35,25 +32,27 @@ private:
void init_scanline_tile(ir::value *i);
public:
tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout);
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
void run(ir::module &mod);
bool hmma(ir::value *value);
int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax);
int wpt(ir::value *value, unsigned ax);
std::vector<int> order(ir::value *v);
const std::map<int, ir::value*>& largest();
private:
// dependencies
analysis::align* align_;
analysis::layout* layout_;
analysis::axes* axes_;
transform::coalesce* coalesce_;
// number of warps
size_t num_warps_;
// tile properties
std::map<int, bool> hmma_;
std::map<int, ir::value*> largest_;
std::map<int, std::vector<int>> order_;
std::map<int, bool> hmma_;
std::map<int, int> fpw_;
std::map<int, int> wpt_;
std::map<int, int> mts_;

View File

@@ -2,6 +2,7 @@
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
@@ -9,27 +10,32 @@ namespace triton {
namespace ir {
class module;
class value;
class io_inst;
}
namespace codegen{
namespace analysis{
class align;
class layout;
class meminfo;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
public:
coalesce(analysis::align* algin, analysis::meminfo* mem);
std::vector<unsigned> get_order(ir::value* v);
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layout* layout_;
analysis::meminfo* mem_;
std::map<ir::value*, std::vector<unsigned>> order_;
};
}

View File

@@ -155,6 +155,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return is_constant_.at(v);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
@@ -300,6 +302,8 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::make_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
if(dynamic_cast<ir::make_range_sta*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}

View File

@@ -1,9 +1,10 @@
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -18,8 +19,8 @@ namespace triton{
namespace codegen{
namespace analysis{
tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout):
num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout)
tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout):
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
{ }
bool is_hmma(ir::value *v){
@@ -57,6 +58,11 @@ int tiles::wpt(ir::value *value, unsigned ax) {
return wpt_.at(axes_->get(value, ax));
}
std::vector<int> tiles::order(ir::value *v) {
auto ret = order_[layout_->id(v)];
return ret;
}
const std::map<int, ir::value*>& tiles::largest() {
return largest_;
}
@@ -68,10 +74,10 @@ unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
void tiles::init_hmma_tile(ir::value *i) {
auto order = coalesce_->get_order(i);
auto ord = order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned shape_0 = shapes[order[0]];
unsigned shape_1 = shapes[order[1]];
unsigned shape_0 = shapes[ord[0]];
unsigned shape_1 = shapes[ord[1]];
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> fpw = {1, 1, 1};
@@ -110,17 +116,17 @@ void tiles::init_hmma_tile(ir::value *i) {
}
void tiles::init_scanline_tile(ir::value *i) {
auto order = coalesce_->get_order(i);
auto ord = order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned size = i->get_type()->get_tile_num_elements();
unsigned ld = order[0];
unsigned ld = ord[0];
unsigned num_threads = num_warps_*32;
unsigned current = num_threads;
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
current = current / mts_[axes_->get(i, ld)];
for(size_t d = 1; d < shapes.size(); d++){
ld = order[d];
ld = ord[d];
nts_[axes_->get(i, ld)] = 1;
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get(i, ld)];
@@ -133,19 +139,20 @@ void tiles::init_scanline_tile(ir::value *i) {
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
size_t num_groups = layout_->get_num_groups();
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
}
// find out which value is the largest in each group
// std::vector<unsigned> axes;
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
// helpers
auto rank = [](ir::value* v) {
ir::type *ty = v->get_type();
size_t ret = 0;
@@ -154,10 +161,36 @@ void tiles::run(ir::module &) {
ret += s > 1;
return ret;
};
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
}
// find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// find out the order of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
};
auto it = std::max_element(io.begin(), io.end(), cmp);
std::vector<int> order(rank(largest_[i]));
std::iota(order.begin(), order.end(), 0);
if(it != io.end()) {
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; }
);
}
order_[i] = order;
}
// tiling parameters
for(auto x: largest_){
ir::value *i = x.second;

View File

@@ -545,7 +545,7 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
* ------------------- */
// Grid construction
std::vector<Value*> delinearize(Value *trailing, const std::vector<unsigned>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
@@ -562,7 +562,7 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<int>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
static const size_t warp_size = 32;
size_t nthreads = 1, nwarps = 1;
nw.resize(bs.size());
@@ -578,7 +578,7 @@ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned
}
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = reorder_->get_order(v);
auto order = tiles_->order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);

View File

@@ -6,6 +6,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/module.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
@@ -14,62 +15,59 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::meminfo *mem)
: align_(align), mem_(mem) { }
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::meminfo *mem)
: align_(align), layout_(layouts), mem_(mem) { }
std::vector<unsigned> coalesce::get_order(ir::value* v) {
return order_.at(v);
// Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
ir::value *ptr = i->get_pointer_operand();
auto contiguous = align_->contiguous(ptr);
auto it = std::max_element(contiguous.begin(), contiguous.end());
int axis = std::distance(contiguous.begin(), it);
result[axis].push_back(i);
}
void coalesce::run(ir::module &mod) {
// find values to rematerialize
size_t num_groups = layout_->get_num_groups();
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
if(order_.find(v) != order_.end())
return;
ir::type *tile_ty = v->get_type();
if(auto *x = dynamic_cast<ir::store_inst*>(v))
tile_ty = x->get_operand(0)->get_type();
if(!tile_ty->is_tile_ty())
return;
std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
std::iota(order.begin(), order.end(), 0);
order_[v] = order;
if(ir::user* u = dynamic_cast<ir::user*>(v))
for(ir::value* op: u->ops())
set_order(op);
};
// initialize work-list
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
for(ir::instruction *i: block->get_inst_list()){
if(auto *x = dynamic_cast<ir::io_inst*>(i)) {
ir::type* ptr_ty = x->get_pointer_operand()->get_type();
if(ptr_ty->is_tile_ty())
io.insert(x);
}
set_order(i);
for(ir::value *v: values)
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io)
extract_ld(i, axes);
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
remat.insert(remat.begin(),
it->second.begin(), it->second.end());
}
// rematerialize values
ir::builder &builder = mod.get_builder();
std::map<ir::value*, ir::value*> replaced;
for(ir::io_inst *i: io) {
ir::value *ptr = i->get_pointer_operand();
auto max_contiguous = align_->contiguous(ptr);
std::vector<unsigned> order(max_contiguous.size());
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
for(ir::io_inst *r: remat) {
std::list<std::pair<ir::instruction*, ir::instruction*>> work_list;
if(order != order_[i])
work_list.push_back({i, nullptr});
std::map<ir::value*, ir::value*> replaced;
work_list.push_back({r, nullptr});
// rematerialize recursively
while(!work_list.empty()) {
auto pair = work_list.back();
ir::instruction* cloned = pair.first;
ir::instruction* original = pair.second;
order_[cloned] = order;
work_list.pop_back();
for(ir::value *op: cloned->ops()) {
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
@@ -101,14 +99,12 @@ void coalesce::run(ir::module &mod) {
}
n_op = builder.insert(n_op);
replaced.insert({i_op, n_op});
order_[n_op] = order;
mem_->copy(n_op, i_op);
if(original)
n_op->erase_use(original);
cloned->replace_uses_of_with(i_op, n_op);
}
}
}
}

View File

@@ -200,30 +200,30 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes
codegen::analysis::meminfo shmem_info;
codegen::analysis::align alignment_info;
codegen::analysis::align align;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&tiles);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&alignment_info);
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
codegen::transform::reassociate reassociate(&align);
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
// run passes
peephole.run(module);
dce.run(module);
alignment_info.run(module);
align.run(module);
shmem_info.run(module);
coalesce.run(module);
dce.run(module);
axes.run(module);
layouts.run(module);
coalesce.run(module);
align.run(module);
dce.run(module);
tiles.run(module);
alignment_info.run(module);
reassociate.run(module);
dce.run(module);
peephole.run(module);
@@ -236,12 +236,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
vectorize.run(module);
dce.run(module);
alignment_info.run(module);
coalesce.run(module);
dce.run(module);
// ir::print(module, std::cout);
axes.run(module);
layouts.run(module);
align.run(module);
tiles.run(module);
selection.run(module, *llvm);
// return binary

View File

@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"AT", {AT?"1":"0"}});
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4};
opt.num_warps = {2, 4, 8};
// create function
rt::function function(src::dot, opt);
// benchmark available libraries