[auto-coalesce] more bugfixes
This commit is contained in:
@@ -19,14 +19,11 @@ namespace ir{
|
|||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
namespace transform{
|
|
||||||
class coalesce;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
|
|
||||||
class axes;
|
class axes;
|
||||||
class layout;
|
class layout;
|
||||||
|
class align;
|
||||||
|
|
||||||
class tiles {
|
class tiles {
|
||||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||||
@@ -35,25 +32,27 @@ private:
|
|||||||
void init_scanline_tile(ir::value *i);
|
void init_scanline_tile(ir::value *i);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
tiles(size_t num_warps, transform::coalesce* coalesce, analysis::axes* axes, analysis::layout* layout);
|
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
bool hmma(ir::value *value);
|
bool hmma(ir::value *value);
|
||||||
int mts(ir::value *value, unsigned ax);
|
int mts(ir::value *value, unsigned ax);
|
||||||
int nts(ir::value *value, unsigned ax);
|
int nts(ir::value *value, unsigned ax);
|
||||||
int fpw(ir::value *value, unsigned ax);
|
int fpw(ir::value *value, unsigned ax);
|
||||||
int wpt(ir::value *value, unsigned ax);
|
int wpt(ir::value *value, unsigned ax);
|
||||||
|
std::vector<int> order(ir::value *v);
|
||||||
const std::map<int, ir::value*>& largest();
|
const std::map<int, ir::value*>& largest();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// dependencies
|
// dependencies
|
||||||
|
analysis::align* align_;
|
||||||
analysis::layout* layout_;
|
analysis::layout* layout_;
|
||||||
analysis::axes* axes_;
|
analysis::axes* axes_;
|
||||||
transform::coalesce* coalesce_;
|
|
||||||
// number of warps
|
// number of warps
|
||||||
size_t num_warps_;
|
size_t num_warps_;
|
||||||
// tile properties
|
// tile properties
|
||||||
std::map<int, bool> hmma_;
|
|
||||||
std::map<int, ir::value*> largest_;
|
std::map<int, ir::value*> largest_;
|
||||||
|
std::map<int, std::vector<int>> order_;
|
||||||
|
std::map<int, bool> hmma_;
|
||||||
std::map<int, int> fpw_;
|
std::map<int, int> fpw_;
|
||||||
std::map<int, int> wpt_;
|
std::map<int, int> wpt_;
|
||||||
std::map<int, int> mts_;
|
std::map<int, int> mts_;
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace triton {
|
namespace triton {
|
||||||
@@ -9,27 +10,32 @@ namespace triton {
|
|||||||
namespace ir {
|
namespace ir {
|
||||||
class module;
|
class module;
|
||||||
class value;
|
class value;
|
||||||
|
class io_inst;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
|
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
class align;
|
class align;
|
||||||
|
class layout;
|
||||||
class meminfo;
|
class meminfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace transform{
|
namespace transform{
|
||||||
|
|
||||||
class coalesce {
|
class coalesce {
|
||||||
|
private:
|
||||||
|
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||||
|
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
coalesce(analysis::align* algin, analysis::meminfo* mem);
|
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem);
|
||||||
std::vector<unsigned> get_order(ir::value* v);
|
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
analysis::align* align_;
|
analysis::align* align_;
|
||||||
|
analysis::layout* layout_;
|
||||||
analysis::meminfo* mem_;
|
analysis::meminfo* mem_;
|
||||||
std::map<ir::value*, std::vector<unsigned>> order_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -155,6 +155,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
|||||||
return is_constant_.at(v);
|
return is_constant_.at(v);
|
||||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||||
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
|
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
|
||||||
|
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||||
|
return add_to_cache(v, {cst_info{true, 0}}, is_constant_);
|
||||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||||
return populate_is_constant_phi(x);
|
return populate_is_constant_phi(x);
|
||||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||||
@@ -300,6 +302,8 @@ std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
|
|||||||
auto shapes = v->get_type()->get_tile_shapes();
|
auto shapes = v->get_type()->get_tile_shapes();
|
||||||
if(dynamic_cast<ir::make_range*>(v))
|
if(dynamic_cast<ir::make_range*>(v))
|
||||||
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||||
|
if(dynamic_cast<ir::make_range_sta*>(v))
|
||||||
|
return add_to_cache(v, {shapes[0]}, max_contiguous_);
|
||||||
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,9 +1,10 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <numeric>
|
||||||
|
#include "triton/codegen/analysis/align.h"
|
||||||
#include "triton/codegen/analysis/axes.h"
|
#include "triton/codegen/analysis/axes.h"
|
||||||
#include "triton/codegen/analysis/tiles.h"
|
#include "triton/codegen/analysis/tiles.h"
|
||||||
#include "triton/codegen/analysis/layout.h"
|
#include "triton/codegen/analysis/layout.h"
|
||||||
#include "triton/codegen/transform/coalesce.h"
|
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
@@ -18,8 +19,8 @@ namespace triton{
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace analysis{
|
namespace analysis{
|
||||||
|
|
||||||
tiles::tiles(size_t num_warps, transform::coalesce *reorder, analysis::axes *axes, analysis::layout *layout):
|
tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout):
|
||||||
num_warps_(num_warps), coalesce_(reorder), axes_(axes), layout_(layout)
|
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
bool is_hmma(ir::value *v){
|
bool is_hmma(ir::value *v){
|
||||||
@@ -57,6 +58,11 @@ int tiles::wpt(ir::value *value, unsigned ax) {
|
|||||||
return wpt_.at(axes_->get(value, ax));
|
return wpt_.at(axes_->get(value, ax));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> tiles::order(ir::value *v) {
|
||||||
|
auto ret = order_[layout_->id(v)];
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
const std::map<int, ir::value*>& tiles::largest() {
|
const std::map<int, ir::value*>& tiles::largest() {
|
||||||
return largest_;
|
return largest_;
|
||||||
}
|
}
|
||||||
@@ -68,10 +74,10 @@ unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
|||||||
|
|
||||||
|
|
||||||
void tiles::init_hmma_tile(ir::value *i) {
|
void tiles::init_hmma_tile(ir::value *i) {
|
||||||
auto order = coalesce_->get_order(i);
|
auto ord = order(i);
|
||||||
auto shapes = i->get_type()->get_tile_shapes();
|
auto shapes = i->get_type()->get_tile_shapes();
|
||||||
unsigned shape_0 = shapes[order[0]];
|
unsigned shape_0 = shapes[ord[0]];
|
||||||
unsigned shape_1 = shapes[order[1]];
|
unsigned shape_1 = shapes[ord[1]];
|
||||||
/* fragments per warp */
|
/* fragments per warp */
|
||||||
// try to make things as square as possible to maximize data re-use
|
// try to make things as square as possible to maximize data re-use
|
||||||
std::vector<unsigned> fpw = {1, 1, 1};
|
std::vector<unsigned> fpw = {1, 1, 1};
|
||||||
@@ -110,17 +116,17 @@ void tiles::init_hmma_tile(ir::value *i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void tiles::init_scanline_tile(ir::value *i) {
|
void tiles::init_scanline_tile(ir::value *i) {
|
||||||
auto order = coalesce_->get_order(i);
|
auto ord = order(i);
|
||||||
auto shapes = i->get_type()->get_tile_shapes();
|
auto shapes = i->get_type()->get_tile_shapes();
|
||||||
unsigned size = i->get_type()->get_tile_num_elements();
|
unsigned size = i->get_type()->get_tile_num_elements();
|
||||||
unsigned ld = order[0];
|
unsigned ld = ord[0];
|
||||||
unsigned num_threads = num_warps_*32;
|
unsigned num_threads = num_warps_*32;
|
||||||
unsigned current = num_threads;
|
unsigned current = num_threads;
|
||||||
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
||||||
current = current / mts_[axes_->get(i, ld)];
|
current = current / mts_[axes_->get(i, ld)];
|
||||||
for(size_t d = 1; d < shapes.size(); d++){
|
for(size_t d = 1; d < shapes.size(); d++){
|
||||||
ld = order[d];
|
ld = ord[d];
|
||||||
nts_[axes_->get(i, ld)] = 1;
|
nts_[axes_->get(i, ld)] = 1;
|
||||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||||
current = current / mts_[axes_->get(i, ld)];
|
current = current / mts_[axes_->get(i, ld)];
|
||||||
@@ -133,19 +139,20 @@ void tiles::init_scanline_tile(ir::value *i) {
|
|||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||||
|
for(ir::user* u: v->get_users()){
|
||||||
|
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||||
|
if(i && i->get_pointer_operand() == v)
|
||||||
|
result.insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void tiles::run(ir::module &) {
|
void tiles::run(ir::module &) {
|
||||||
hmma_.clear();
|
hmma_.clear();
|
||||||
largest_.clear();
|
largest_.clear();
|
||||||
size_t num_groups = layout_->get_num_groups();
|
size_t num_groups = layout_->get_num_groups();
|
||||||
// find out which groups require hmma layout
|
// helpers
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
|
||||||
const auto& values = layout_->values(i);
|
|
||||||
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
|
||||||
}
|
|
||||||
// find out which value is the largest in each group
|
|
||||||
// std::vector<unsigned> axes;
|
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
|
||||||
const auto& values = layout_->values(i);
|
|
||||||
auto rank = [](ir::value* v) {
|
auto rank = [](ir::value* v) {
|
||||||
ir::type *ty = v->get_type();
|
ir::type *ty = v->get_type();
|
||||||
size_t ret = 0;
|
size_t ret = 0;
|
||||||
@@ -154,10 +161,36 @@ void tiles::run(ir::module &) {
|
|||||||
ret += s > 1;
|
ret += s > 1;
|
||||||
return ret;
|
return ret;
|
||||||
};
|
};
|
||||||
|
// find out which groups require hmma layout
|
||||||
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
|
const auto& values = layout_->values(i);
|
||||||
|
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
||||||
|
}
|
||||||
|
// find out which value is the largest in each group
|
||||||
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
|
const auto& values = layout_->values(i);
|
||||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||||
}
|
}
|
||||||
|
// find out the order of a group
|
||||||
|
for(size_t i = 0; i < num_groups; i++){
|
||||||
|
std::set<ir::io_inst*> io;
|
||||||
|
for(ir::value* v: layout_->values(i))
|
||||||
|
extract_io_use(v, io);
|
||||||
|
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||||
|
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||||
|
};
|
||||||
|
auto it = std::max_element(io.begin(), io.end(), cmp);
|
||||||
|
std::vector<int> order(rank(largest_[i]));
|
||||||
|
std::iota(order.begin(), order.end(), 0);
|
||||||
|
if(it != io.end()) {
|
||||||
|
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
|
||||||
|
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
||||||
|
return max_contiguous[a] > max_contiguous[b]; }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
order_[i] = order;
|
||||||
|
}
|
||||||
// tiling parameters
|
// tiling parameters
|
||||||
for(auto x: largest_){
|
for(auto x: largest_){
|
||||||
ir::value *i = x.second;
|
ir::value *i = x.second;
|
||||||
|
@@ -545,7 +545,7 @@ Value* selection::llvm_value(ir::value *v, IRBuilder<> &builder) {
|
|||||||
* ------------------- */
|
* ------------------- */
|
||||||
|
|
||||||
// Grid construction
|
// Grid construction
|
||||||
std::vector<Value*> delinearize(Value *trailing, const std::vector<unsigned>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
||||||
size_t dim = shapes.size();
|
size_t dim = shapes.size();
|
||||||
std::vector<Value*> result(dim);
|
std::vector<Value*> result(dim);
|
||||||
for(unsigned k = 0; k < dim - 1; k++){
|
for(unsigned k = 0; k < dim - 1; k++){
|
||||||
@@ -562,7 +562,7 @@ inline int32_t ceil(int32_t num, int32_t div){
|
|||||||
return (num + div - 1)/div;
|
return (num + div - 1)/div;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
|
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<int>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
|
||||||
static const size_t warp_size = 32;
|
static const size_t warp_size = 32;
|
||||||
size_t nthreads = 1, nwarps = 1;
|
size_t nthreads = 1, nwarps = 1;
|
||||||
nw.resize(bs.size());
|
nw.resize(bs.size());
|
||||||
@@ -578,7 +578,7 @@ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned
|
|||||||
}
|
}
|
||||||
|
|
||||||
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||||
auto order = reorder_->get_order(v);
|
auto order = tiles_->order(v);
|
||||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||||
size_t dim = shapes.size();
|
size_t dim = shapes.size();
|
||||||
std::vector<unsigned> contiguous(dim);
|
std::vector<unsigned> contiguous(dim);
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
#include "triton/ir/basic_block.h"
|
#include "triton/ir/basic_block.h"
|
||||||
#include "triton/ir/instructions.h"
|
#include "triton/ir/instructions.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/codegen/analysis/layout.h"
|
||||||
#include "triton/codegen/analysis/meminfo.h"
|
#include "triton/codegen/analysis/meminfo.h"
|
||||||
#include "triton/codegen/analysis/align.h"
|
#include "triton/codegen/analysis/align.h"
|
||||||
#include "triton/codegen/transform/coalesce.h"
|
#include "triton/codegen/transform/coalesce.h"
|
||||||
@@ -14,62 +15,59 @@ namespace triton {
|
|||||||
namespace codegen{
|
namespace codegen{
|
||||||
namespace transform{
|
namespace transform{
|
||||||
|
|
||||||
coalesce::coalesce(analysis::align* align, analysis::meminfo *mem)
|
coalesce::coalesce(analysis::align* align, analysis::layout *layouts, analysis::meminfo *mem)
|
||||||
: align_(align), mem_(mem) { }
|
: align_(align), layout_(layouts), mem_(mem) { }
|
||||||
|
|
||||||
std::vector<unsigned> coalesce::get_order(ir::value* v) {
|
// Find all values that are used as pointer operands in LD/ST
|
||||||
return order_.at(v);
|
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||||
|
for(ir::user* u: v->get_users()){
|
||||||
|
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||||
|
if(i && i->get_pointer_operand() == v)
|
||||||
|
result.insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
|
||||||
|
ir::value *ptr = i->get_pointer_operand();
|
||||||
|
auto contiguous = align_->contiguous(ptr);
|
||||||
|
auto it = std::max_element(contiguous.begin(), contiguous.end());
|
||||||
|
int axis = std::distance(contiguous.begin(), it);
|
||||||
|
result[axis].push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
void coalesce::run(ir::module &mod) {
|
void coalesce::run(ir::module &mod) {
|
||||||
|
// find values to rematerialize
|
||||||
|
size_t num_groups = layout_->get_num_groups();
|
||||||
|
std::vector<ir::io_inst*> remat;
|
||||||
|
for(size_t id = 0; id < num_groups; id++) {
|
||||||
|
const auto& values = layout_->values(id);
|
||||||
|
// extract pointers used in ld/st operations
|
||||||
std::set<ir::io_inst*> io;
|
std::set<ir::io_inst*> io;
|
||||||
|
for(ir::value *v: values)
|
||||||
std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
|
extract_io_use(v, io);
|
||||||
if(order_.find(v) != order_.end())
|
// extract leading axes
|
||||||
return;
|
std::map<int, std::vector<ir::io_inst*>> axes;
|
||||||
ir::type *tile_ty = v->get_type();
|
for(ir::io_inst *i: io)
|
||||||
if(auto *x = dynamic_cast<ir::store_inst*>(v))
|
extract_ld(i, axes);
|
||||||
tile_ty = x->get_operand(0)->get_type();
|
// update list of values to rematerialize
|
||||||
if(!tile_ty->is_tile_ty())
|
if(axes.empty())
|
||||||
return;
|
continue;
|
||||||
std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
|
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
|
||||||
std::iota(order.begin(), order.end(), 0);
|
remat.insert(remat.begin(),
|
||||||
order_[v] = order;
|
it->second.begin(), it->second.end());
|
||||||
if(ir::user* u = dynamic_cast<ir::user*>(v))
|
|
||||||
for(ir::value* op: u->ops())
|
|
||||||
set_order(op);
|
|
||||||
};
|
|
||||||
|
|
||||||
// initialize work-list
|
|
||||||
for(ir::function *fn: mod.get_function_list())
|
|
||||||
for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
|
|
||||||
for(ir::instruction *i: block->get_inst_list()){
|
|
||||||
if(auto *x = dynamic_cast<ir::io_inst*>(i)) {
|
|
||||||
ir::type* ptr_ty = x->get_pointer_operand()->get_type();
|
|
||||||
if(ptr_ty->is_tile_ty())
|
|
||||||
io.insert(x);
|
|
||||||
}
|
|
||||||
set_order(i);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rematerialize values
|
||||||
ir::builder &builder = mod.get_builder();
|
ir::builder &builder = mod.get_builder();
|
||||||
std::map<ir::value*, ir::value*> replaced;
|
for(ir::io_inst *r: remat) {
|
||||||
for(ir::io_inst *i: io) {
|
|
||||||
ir::value *ptr = i->get_pointer_operand();
|
|
||||||
auto max_contiguous = align_->contiguous(ptr);
|
|
||||||
std::vector<unsigned> order(max_contiguous.size());
|
|
||||||
std::iota(order.begin(), order.end(), 0);
|
|
||||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
|
|
||||||
std::list<std::pair<ir::instruction*, ir::instruction*>> work_list;
|
std::list<std::pair<ir::instruction*, ir::instruction*>> work_list;
|
||||||
if(order != order_[i])
|
std::map<ir::value*, ir::value*> replaced;
|
||||||
work_list.push_back({i, nullptr});
|
work_list.push_back({r, nullptr});
|
||||||
// rematerialize recursively
|
// rematerialize recursively
|
||||||
while(!work_list.empty()) {
|
while(!work_list.empty()) {
|
||||||
auto pair = work_list.back();
|
auto pair = work_list.back();
|
||||||
ir::instruction* cloned = pair.first;
|
ir::instruction* cloned = pair.first;
|
||||||
ir::instruction* original = pair.second;
|
ir::instruction* original = pair.second;
|
||||||
order_[cloned] = order;
|
|
||||||
work_list.pop_back();
|
work_list.pop_back();
|
||||||
for(ir::value *op: cloned->ops()) {
|
for(ir::value *op: cloned->ops()) {
|
||||||
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
||||||
@@ -101,14 +99,12 @@ void coalesce::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
n_op = builder.insert(n_op);
|
n_op = builder.insert(n_op);
|
||||||
replaced.insert({i_op, n_op});
|
replaced.insert({i_op, n_op});
|
||||||
order_[n_op] = order;
|
|
||||||
mem_->copy(n_op, i_op);
|
mem_->copy(n_op, i_op);
|
||||||
if(original)
|
if(original)
|
||||||
n_op->erase_use(original);
|
n_op->erase_use(original);
|
||||||
cloned->replace_uses_of_with(i_op, n_op);
|
cloned->replace_uses_of_with(i_op, n_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -200,30 +200,30 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||||
// create passes
|
// create passes
|
||||||
codegen::analysis::meminfo shmem_info;
|
codegen::analysis::meminfo shmem_info;
|
||||||
codegen::analysis::align alignment_info;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||||
codegen::transform::coalesce coalesce(&alignment_info, &shmem_info);
|
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
codegen::analysis::layout layouts(&axes);
|
codegen::analysis::layout layouts(&axes);
|
||||||
codegen::analysis::tiles tiles(opt.num_warps, &coalesce, &axes, &layouts);
|
codegen::transform::coalesce coalesce(&align, &layouts, &shmem_info);
|
||||||
|
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
|
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &tiles);
|
||||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||||
codegen::transform::vectorize vectorize(&tiles);
|
codegen::transform::vectorize vectorize(&tiles);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole;
|
codegen::transform::peephole peephole;
|
||||||
codegen::transform::reassociate reassociate(&alignment_info);
|
codegen::transform::reassociate reassociate(&align);
|
||||||
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &alignment_info, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
codegen::selection selection(&shmem_allocation, &tiles, &shmem_info, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
alignment_info.run(module);
|
align.run(module);
|
||||||
shmem_info.run(module);
|
shmem_info.run(module);
|
||||||
coalesce.run(module);
|
|
||||||
dce.run(module);
|
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
|
coalesce.run(module);
|
||||||
|
align.run(module);
|
||||||
|
dce.run(module);
|
||||||
tiles.run(module);
|
tiles.run(module);
|
||||||
alignment_info.run(module);
|
|
||||||
reassociate.run(module);
|
reassociate.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
@@ -236,12 +236,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
dce.run(module);
|
dce.run(module);
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
alignment_info.run(module);
|
|
||||||
coalesce.run(module);
|
|
||||||
dce.run(module);
|
|
||||||
// ir::print(module, std::cout);
|
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
|
align.run(module);
|
||||||
tiles.run(module);
|
tiles.run(module);
|
||||||
selection.run(module, *llvm);
|
selection.run(module, *llvm);
|
||||||
// return binary
|
// return binary
|
||||||
|
@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"TYPE", {ty}});
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"64", "128"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"64", "128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {4};
|
opt.num_warps = {2, 4, 8};
|
||||||
// create function
|
// create function
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
// benchmark available libraries
|
// benchmark available libraries
|
||||||
|
Reference in New Issue
Block a user