[auto-coalesce] more bugfixes
This commit is contained in:
@@ -19,14 +19,11 @@ namespace ir{
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class layout;
|
||||
class align;
|
||||
|
||||
class tiles {
|
||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||
@@ -35,25 +32,27 @@ private:
|
||||
void init_scanline_tile(ir::value *i);
|
||||
|
||||
public:
|
||||
tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout);
|
||||
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||
void run(ir::module &mod);
|
||||
bool hmma(ir::value *value);
|
||||
int mts(ir::value *value, unsigned ax);
|
||||
int nts(ir::value *value, unsigned ax);
|
||||
int fpw(ir::value *value, unsigned ax);
|
||||
int wpt(ir::value *value, unsigned ax);
|
||||
std::vector<int> order(ir::value *v);
|
||||
const std::map<int, ir::value*>& largest();
|
||||
|
||||
private:
|
||||
// dependencies
|
||||
analysis::align* align_;
|
||||
analysis::layout* layout_;
|
||||
analysis::axes* axes_;
|
||||
transform::coalesce* coalesce_;
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// tile properties
|
||||
std::map<int, bool> hmma_;
|
||||
std::map<int, ir::value*> largest_;
|
||||
std::map<int, std::vector<int>> order_;
|
||||
std::map<int, bool> hmma_;
|
||||
std::map<int, int> fpw_;
|
||||
std::map<int, int> wpt_;
|
||||
std::map<int, int> mts_;
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
@@ -9,27 +10,32 @@ namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class io_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class align;
|
||||
class layout;
|
||||
class meminfo;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class coalesce {
|
||||
private:
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* algin, analysis::meminfo* mem);
|
||||
std::vector<unsigned> get_order(ir::value* v);
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::align* align_;
|
||||
analysis::layout* layout_;
|
||||
analysis::meminfo* mem_;
|
||||
std::map<ir::value*, std::vector<unsigned>> order_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -155,6 +155,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
return is_constant_.at(v);
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_is_constant_phi(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
@@ -300,6 +302,8 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::make_range*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||
}
|
||||
|
||||
|
@@ -1,9 +1,10 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <numeric>
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -18,8 +19,8 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout):
|
||||
num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout)
|
||||
tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout):
|
||||
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -57,6 +58,11 @@ int tiles::wpt(ir::value *value, unsigned ax) {
|
||||
return wpt_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
std::vector<int> tiles::order(ir::value *v) {
|
||||
auto ret = order_[layout_->id(v)];
|
||||
return ret;
|
||||
}
|
||||
|
||||
const std::map<int, ir::value*>& tiles::largest() {
|
||||
return largest_;
|
||||
}
|
||||
@@ -68,10 +74,10 @@ unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
|
||||
|
||||
void tiles::init_hmma_tile(ir::value *i) {
|
||||
auto order = coalesce_->get_order(i);
|
||||
auto ord = order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned shape_0 = shapes[order[0]];
|
||||
unsigned shape_1 = shapes[order[1]];
|
||||
unsigned shape_0 = shapes[ord[0]];
|
||||
unsigned shape_1 = shapes[ord[1]];
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> fpw = {1, 1, 1};
|
||||
@@ -110,17 +116,17 @@ void tiles::init_hmma_tile(ir::value *i) {
|
||||
}
|
||||
|
||||
void tiles::init_scanline_tile(ir::value *i) {
|
||||
auto order = coalesce_->get_order(i);
|
||||
auto ord = order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned size = i->get_type()->get_tile_num_elements();
|
||||
unsigned ld = order[0];
|
||||
unsigned ld = ord[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = order[d];
|
||||
ld = ord[d];
|
||||
nts_[axes_->get(i, ld)] = 1;
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
@@ -133,19 +139,20 @@ void tiles::init_scanline_tile(ir::value *i) {
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void tiles::run(ir::module &) {
|
||||
hmma_.clear();
|
||||
largest_.clear();
|
||||
size_t num_groups = layout_->get_num_groups();
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
||||
}
|
||||
// find out which value is the largest in each group
|
||||
// std::vector<unsigned> axes;
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
// helpers
|
||||
auto rank = [](ir::value* v) {
|
||||
ir::type *ty = v->get_type();
|
||||
size_t ret = 0;
|
||||
@@ -154,10 +161,36 @@ void tiles::run(ir::module &) {
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
};
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
||||
}
|
||||
// find out which value is the largest in each group
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||
}
|
||||
|
||||
// find out the order of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value* v: layout_->values(i))
|
||||
extract_io_use(v, io);
|
||||
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||
};
|
||||
auto it = std::max_element(io.begin(), io.end(), cmp);
|
||||
std::vector<int> order(rank(largest_[i]));
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
if(it != io.end()) {
|
||||
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b]; }
|
||||
);
|
||||
}
|
||||
order_[i] = order;
|
||||
}
|
||||
// tiling parameters
|
||||
for(auto x: largest_){
|
||||
ir::value *i = x.second;
|
||||
|
@@ -545,7 +545,7 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
|
||||
* ------------------- */
|
||||
|
||||
// Grid construction
|
||||
std::vector<Value*> delinearize(Value *trailing, const std::vector<unsigned>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
||||
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
||||
size_t dim = shapes.size();
|
||||
std::vector<Value*> result(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
@@ -562,7 +562,7 @@ inline int32_t ceil(int32_t num, int32_t div){
|
||||
return (num + div - 1)/div;
|
||||
}
|
||||
|
||||
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
|
||||
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<int>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
|
||||
static const size_t warp_size = 32;
|
||||
size_t nthreads = 1, nwarps = 1;
|
||||
nw.resize(bs.size());
|
||||
@@ -578,7 +578,7 @@ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned
|
||||
}
|
||||
|
||||
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
auto order = reorder_->get_order(v);
|
||||
auto order = tiles_->order(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
@@ -14,62 +15,59 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
coalesce::coalesce(analysis::align* align, analysis::meminfo *mem)
|
||||
: align_(align), mem_(mem) { }
|
||||
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::meminfo *mem)
|
||||
: align_(align), layout_(layouts), mem_(mem) { }
|
||||
|
||||
std::vector<unsigned> coalesce::get_order(ir::value* v) {
|
||||
return order_.at(v);
|
||||
// Find all values that are used as pointer operands in LD/ST
|
||||
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
|
||||
ir::value *ptr = i->get_pointer_operand();
|
||||
auto contiguous = align_->contiguous(ptr);
|
||||
auto it = std::max_element(contiguous.begin(), contiguous.end());
|
||||
int axis = std::distance(contiguous.begin(), it);
|
||||
result[axis].push_back(i);
|
||||
}
|
||||
|
||||
void coalesce::run(ir::module &mod) {
|
||||
|
||||
// find values to rematerialize
|
||||
size_t num_groups = layout_->get_num_groups();
|
||||
std::vector<ir::io_inst*> remat;
|
||||
for(size_t id = 0; id < num_groups; id++) {
|
||||
const auto& values = layout_->values(id);
|
||||
// extract pointers used in ld/st operations
|
||||
std::set<ir::io_inst*> io;
|
||||
|
||||
std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
|
||||
if(order_.find(v) != order_.end())
|
||||
return;
|
||||
ir::type *tile_ty = v->get_type();
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(v))
|
||||
tile_ty = x->get_operand(0)->get_type();
|
||||
if(!tile_ty->is_tile_ty())
|
||||
return;
|
||||
std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
order_[v] = order;
|
||||
if(ir::user* u = dynamic_cast<ir::user*>(v))
|
||||
for(ir::value* op: u->ops())
|
||||
set_order(op);
|
||||
};
|
||||
|
||||
// initialize work-list
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(auto *x = dynamic_cast<ir::io_inst*>(i)) {
|
||||
ir::type* ptr_ty = x->get_pointer_operand()->get_type();
|
||||
if(ptr_ty->is_tile_ty())
|
||||
io.insert(x);
|
||||
}
|
||||
set_order(i);
|
||||
for(ir::value *v: values)
|
||||
extract_io_use(v, io);
|
||||
// extract leading axes
|
||||
std::map<int, std::vector<ir::io_inst*>> axes;
|
||||
for(ir::io_inst *i: io)
|
||||
extract_ld(i, axes);
|
||||
// update list of values to rematerialize
|
||||
if(axes.empty())
|
||||
continue;
|
||||
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
|
||||
remat.insert(remat.begin(),
|
||||
it->second.begin(), it->second.end());
|
||||
}
|
||||
|
||||
// rematerialize values
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::map<ir::value*, ir::value*> replaced;
|
||||
for(ir::io_inst *i: io) {
|
||||
ir::value *ptr = i->get_pointer_operand();
|
||||
auto max_contiguous = align_->contiguous(ptr);
|
||||
std::vector<unsigned> order(max_contiguous.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
|
||||
for(ir::io_inst *r: remat) {
|
||||
std::list<std::pair<ir::instruction*, ir::instruction*>> work_list;
|
||||
if(order != order_[i])
|
||||
work_list.push_back({i, nullptr});
|
||||
std::map<ir::value*, ir::value*> replaced;
|
||||
work_list.push_back({r, nullptr});
|
||||
// rematerialize recursively
|
||||
while(!work_list.empty()) {
|
||||
auto pair = work_list.back();
|
||||
ir::instruction* cloned = pair.first;
|
||||
ir::instruction* original = pair.second;
|
||||
order_[cloned] = order;
|
||||
work_list.pop_back();
|
||||
for(ir::value *op: cloned->ops()) {
|
||||
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
||||
@@ -101,14 +99,12 @@ void coalesce::run(ir::module &mod) {
|
||||
}
|
||||
n_op = builder.insert(n_op);
|
||||
replaced.insert({i_op, n_op});
|
||||
order_[n_op] = order;
|
||||
mem_->copy(n_op, i_op);
|
||||
if(original)
|
||||
n_op->erase_use(original);
|
||||
cloned->replace_uses_of_with(i_op, n_op);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -200,30 +200,30 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
// create passes
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&tiles);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&alignment_info);
|
||||
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
alignment_info.run(module);
|
||||
align.run(module);
|
||||
shmem_info.run(module);
|
||||
coalesce.run(module);
|
||||
dce.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
coalesce.run(module);
|
||||
align.run(module);
|
||||
dce.run(module);
|
||||
tiles.run(module);
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
peephole.run(module);
|
||||
@@ -236,12 +236,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
alignment_info.run(module);
|
||||
coalesce.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
align.run(module);
|
||||
tiles.run(module);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
|
@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {4};
|
||||
opt.num_warps = {2, 4, 8};
|
||||
// create function
|
||||
rt::function function(src::dot, opt);
|
||||
// benchmark available libraries
|
||||
|
Reference in New Issue
Block a user