[codegen] more cleaning
This commit is contained in:
@@ -19,6 +19,20 @@ namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class align;
|
||||
|
||||
// Kind of layout attached to a group of values by the layout analysis.
enum layout_type_t {
  // picked by layout::run when any value of the group satisfies is_hmma_c
  HMMA_884 = 0,
  // picked for every other group
  SCANLINE = 1
};
|
||||
|
||||
struct layout_t {
|
||||
layout_type_t type;
|
||||
ir::value *i;
|
||||
std::vector<int> axes;
|
||||
std::vector<unsigned> shapes;
|
||||
std::vector<int> order;
|
||||
};
|
||||
|
||||
class layout {
|
||||
typedef ir::value* node_t;
|
||||
@@ -31,19 +45,24 @@ private:
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layout(analysis::axes *axes);
|
||||
layout(analysis::axes *axes, analysis::align *align);
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const;
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const;
|
||||
size_t num_layouts() const;
|
||||
layout_t get(ir::value *v) const;
|
||||
const std::map<size_t, layout_t>& get_all() const;
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, layout_t> layouts_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -22,6 +22,7 @@ namespace analysis{
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class tiles;
|
||||
class layout;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
@@ -72,7 +73,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
liveness(tiles *t): tiles_(t){ }
|
||||
liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ }
|
||||
// padding
|
||||
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
|
||||
// buffer size
|
||||
@@ -92,6 +93,7 @@ public:
|
||||
private:
|
||||
// analysis
|
||||
tiles *tiles_;
|
||||
layout *layouts_;
|
||||
// stuff
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
indices_map_t indices;
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
@@ -25,28 +26,22 @@ class axes;
|
||||
class layout;
|
||||
class align;
|
||||
|
||||
// Tile layout kind recorded per group by tiles::run.
enum layout_t {
  // default for groups without a matching dot instruction
  SCANLINE = 0,
  // assigned when a value of the group satisfies is_hmma_c
  HMMA_C = 1
};
|
||||
|
||||
class tiles {
|
||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||
private:
|
||||
void init_hmma_tile(ir::value *i);
|
||||
void init_scanline_tile(ir::value *i);
|
||||
void init_hmma_tile(const layout_t& layout);
|
||||
void init_scanline_tile(const layout_t& layout);
|
||||
bool is_trans(ir::value *i);
|
||||
|
||||
public:
|
||||
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||
void run(ir::module &mod);
|
||||
layout_t hmma(ir::value *value);
|
||||
int mts(ir::value *value, unsigned ax);
|
||||
int nts(ir::value *value, unsigned ax);
|
||||
int fpw(ir::value *value, unsigned ax);
|
||||
int wpt(ir::value *value, unsigned ax);
|
||||
std::vector<int> order(ir::value *v);
|
||||
const std::map<int, ir::value*>& largest();
|
||||
|
||||
|
||||
private:
|
||||
// dependencies
|
||||
@@ -56,9 +51,6 @@ private:
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// tile properties
|
||||
std::map<int, ir::value*> largest_;
|
||||
std::map<int, std::vector<int>> order_;
|
||||
std::map<int, layout_t> hmma_;
|
||||
std::map<int, int> fpw_;
|
||||
std::map<int, int> wpt_;
|
||||
std::map<int, int> mts_;
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
|
||||
|
||||
@@ -171,9 +172,9 @@ private:
|
||||
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
|
||||
void create_distributed_tile(ir::value *v, Builder &builder);
|
||||
void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
|
||||
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
|
||||
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||
|
||||
// lower scalar instruction
|
||||
|
@@ -1,6 +1,8 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -12,8 +14,8 @@ namespace analysis{
|
||||
|
||||
|
||||
// constructor
|
||||
layout::layout(analysis::axes *axes)
|
||||
: axes_(axes) { }
|
||||
layout::layout(analysis::axes *axes, analysis::align *align)
|
||||
: axes_(axes), align_(align) { }
|
||||
|
||||
// get group id
|
||||
unsigned layout::layout_of(ir::value *value) const
|
||||
@@ -56,6 +58,51 @@ void layout::make_graph(ir::instruction *i) {
|
||||
}
|
||||
}
|
||||
|
||||
// hmma
|
||||
// Returns true iff `v` is an ir::dot_inst whose operand 0 and operand 1 both
// have a scalar type for which is_half_ty() holds; false for anything else.
bool is_hmma_c(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(!dot)
    return false;
  ir::type *lhs_ty = dot->get_operand(0)->get_type();
  ir::type *rhs_ty = dot->get_operand(1)->get_type();
  if(!lhs_ty->get_scalar_ty()->is_half_ty())
    return false;
  return rhs_ty->get_scalar_ty()->is_half_ty();
}
|
||||
|
||||
// Returns a copy of the layout of the group that `v` belongs to.
// Both lookups go through std::map::at, so a value with no group (or a group
// with no layout) raises std::out_of_range.
layout_t layout::get(ir::value *v) const {
  const size_t group_id = groups_.at(v);
  return layouts_.at(group_id);
}
|
||||
|
||||
const std::map<size_t, layout_t>& layout::get_all() const {
|
||||
return layouts_;
|
||||
}
|
||||
|
||||
// Adds to `result` every user of `v` that is an ir::io_inst and that uses `v`
// as its pointer operand. Users that touch `v` in any other role are ignored.
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
  for(ir::user* usr: v->get_users()){
    auto *io = dynamic_cast<ir::io_inst*>(usr);
    if(!io)
      continue;
    if(io->get_pointer_operand() != v)
      continue;
    result.insert(io);
  }
}
|
||||
|
||||
|
||||
// Returns true when `v` is an ir::trans_inst, or when `v` is any other
// instruction all of whose operands are (recursively) is_trans.
// An instruction with no operands therefore yields true; a value that is not
// an instruction yields false.
// NOTE(review): the recursion keeps no visited set; presumably the operand
// chains reaching this helper are acyclic -- confirm for phi nodes in loops.
inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v))
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(!inst)
    return false;
  for(ir::value *operand: inst->ops())
    if(!is_trans(operand))
      return false;
  return true;
}
|
||||
|
||||
|
||||
void layout::run(ir::module &mod) {
|
||||
// make graph
|
||||
graph_.clear();
|
||||
@@ -64,6 +111,82 @@ void layout::run(ir::module &mod) {
|
||||
});
|
||||
// connected components
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
// create layouts
|
||||
for(const auto& x: values_) {
|
||||
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
|
||||
layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE;
|
||||
|
||||
}
|
||||
|
||||
|
||||
/* ---- TO CLEAN ---- */
|
||||
|
||||
size_t num_groups = num_layouts();
|
||||
// helpers
|
||||
auto rank = [this](ir::value* v) {
|
||||
int ret = 0;
|
||||
for(int s: v->get_type()->get_tile_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
};
|
||||
|
||||
// find out which value is the largest in each group
|
||||
for(const auto& x: values_) {
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
|
||||
layouts_[x.first].axes = axes_->get(largest);
|
||||
layouts_[x.first].i = largest;
|
||||
layouts_[x.first].shapes = largest->get_type()->get_tile_shapes();
|
||||
}
|
||||
|
||||
|
||||
// find out the layout ordering of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value* v: values_of(i))
|
||||
extract_io_use(v, io);
|
||||
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||
};
|
||||
auto it = std::max_element(io.begin(), io.end(), cmp);
|
||||
std::vector<int> order(layouts_[i].axes.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
if(it != io.end()) {
|
||||
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b]; }
|
||||
);
|
||||
}
|
||||
layouts_[i].order = order;
|
||||
}
|
||||
// matrix multiplication optimizations
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
for(ir::value* v: values_of(i))
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
dots.push_back(x);
|
||||
for(ir::dot_inst* dot: dots){
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
if(get(dot).type == HMMA_884){
|
||||
auto a_val = values_of(layout_of(a));
|
||||
auto b_val = values_of(layout_of(b));
|
||||
for(ir::value *v: a_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
layouts_[layout_of(a)].order = layouts_[layout_of(cts->get_operand(0))].order;
|
||||
for(ir::value *v: b_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
layouts_[layout_of(b)].order = layouts_[layout_of(cts->get_operand(0))].order;
|
||||
}
|
||||
else{
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
layouts_[layout_of(a)].order = is_trans(a) ? row : col;
|
||||
layouts_[layout_of(b)].order = is_trans(b) ? col : row;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -89,8 +90,8 @@ bool liveness::do_pad(ir::value *x) {
|
||||
ir::value *b = dot->get_operand(1);
|
||||
size_t a_previous = pad_[a];
|
||||
size_t b_previous = pad_[b];
|
||||
auto a_order = tiles_->order(a);
|
||||
auto b_order = tiles_->order(b);
|
||||
auto a_order = layouts_->get(a).order;
|
||||
auto b_order = layouts_->get(b).order;
|
||||
bool a_row = is_trans(a) ^ (a_order[0] == 1);
|
||||
bool b_row = is_trans(b) ^ (b_order[0] == 1);
|
||||
auto a_shapes = a->get_type()->get_tile_shapes();
|
||||
@@ -108,9 +109,9 @@ bool liveness::do_pad(ir::value *x) {
|
||||
}
|
||||
// padding for copy to shared
|
||||
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
|
||||
auto cts_order = tiles_->order(cts);
|
||||
auto cts_order = layouts_->get(cts).order;
|
||||
ir::value *arg = cts->get_operand(0);
|
||||
auto arg_order = tiles_->order(arg);
|
||||
auto arg_order = layouts_->get(arg).order;
|
||||
size_t previous = pad_[cts];
|
||||
if(cts_order != arg_order)
|
||||
pad_[cts] = std::max<int>(pad_[cts], 4);
|
||||
@@ -144,7 +145,7 @@ unsigned liveness::num_bytes(ir::value *x) {
|
||||
for(auto x: shapes)
|
||||
num_elements *= x;
|
||||
size_t depth;
|
||||
if(tiles_->hmma(x))
|
||||
if(layouts_->get(x).type == HMMA_884)
|
||||
depth = tiles_->wpt(op, axis);
|
||||
else
|
||||
depth = tiles_->mts(op, axis);
|
||||
@@ -153,7 +154,7 @@ unsigned liveness::num_bytes(ir::value *x) {
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
unsigned pad = pad_.at(x);
|
||||
if(pad > 0){
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x).order[0]];
|
||||
num_bytes += pad * num_bytes / ld;
|
||||
}
|
||||
if(has_double(x))
|
||||
|
@@ -23,59 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
|
||||
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
||||
{ }
|
||||
|
||||
// True only for an ir::dot_inst whose first two operands both have a scalar
// type reporting is_half_ty(); every other value gives false.
bool is_hmma_c(ir::value *v){
  auto *dot = dynamic_cast<ir::dot_inst*>(v);
  if(dot == nullptr)
    return false;
  ir::type *a_type = dot->get_operand(0)->get_type();
  ir::type *b_type = dot->get_operand(1)->get_type();
  const bool a_is_half = a_type->get_scalar_ty()->is_half_ty();
  return a_is_half && b_type->get_scalar_ty()->is_half_ty();
}
|
||||
|
||||
// Returns true when some user of `v` satisfies is_hmma_c and takes `v` as its
// operand 0; returns false otherwise.
// The original fell off the end of this bool function when no user matched,
// which is undefined behaviour; the explicit `return false` fixes that.
bool is_hmma_a_col(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only succeeds for dot instructions, so the cast cannot fail
      auto *dot = dynamic_cast<ir::dot_inst*>(u);
      if(dot && v == dot->get_operand(0))
        return true;
    }
  return false;
}
|
||||
|
||||
// Returns true when some user of `v` satisfies is_hmma_c and takes `v` as its
// operand 0; returns false otherwise.
// The original had no return on the no-match path (undefined behaviour for a
// bool function); an explicit `return false` is now provided.
// NOTE(review): the test performed here is the same as in is_hmma_a_col --
// nothing inspects row/column ordering; confirm that is intended.
bool is_hmma_a_row(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only succeeds for dot instructions, so the cast cannot fail
      auto *dot = dynamic_cast<ir::dot_inst*>(u);
      if(dot && v == dot->get_operand(0))
        return true;
    }
  return false;
}
|
||||
|
||||
// Returns true when some user of `v` satisfies is_hmma_c and takes `v` as its
// operand 1; returns false otherwise.
// The original fell off the end of this bool function when no user matched,
// which is undefined behaviour; the explicit `return false` fixes that.
bool is_hmma_b_col(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only succeeds for dot instructions, so the cast cannot fail
      auto *dot = dynamic_cast<ir::dot_inst*>(u);
      if(dot && v == dot->get_operand(1))
        return true;
    }
  return false;
}
|
||||
|
||||
// Returns true when some user of `v` satisfies is_hmma_c and takes `v` as its
// operand 1; returns false otherwise.
// The original had no return on the no-match path (undefined behaviour for a
// bool function); an explicit `return false` is now provided.
// NOTE(review): the test performed here is the same as in is_hmma_b_col --
// nothing inspects row/column ordering; confirm that is intended.
bool is_hmma_b_row(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only succeeds for dot instructions, so the cast cannot fail
      auto *dot = dynamic_cast<ir::dot_inst*>(u);
      if(dot && v == dot->get_operand(1))
        return true;
    }
  return false;
}
|
||||
|
||||
|
||||
// Layout kind recorded for the group containing `value`.
// Uses std::map::at, so a group without an entry raises std::out_of_range.
layout_t tiles::hmma(ir::value *value) {
  const auto group_id = layout_->layout_of(value);
  return hmma_.at(group_id);
}
|
||||
|
||||
int tiles::mts(ir::value *value, unsigned ax) {
|
||||
return mts_.at(axes_->get(value, ax));
|
||||
@@ -93,24 +41,15 @@ int tiles::wpt(ir::value *value, unsigned ax) {
|
||||
return wpt_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
// Returns a copy of the dimension ordering stored for the group of `v`.
// The lookup uses operator[], so a group with no stored ordering gets an
// empty vector inserted and returned.
std::vector<int> tiles::order(ir::value *v) {
  return order_[layout_->layout_of(v)];
}
|
||||
|
||||
const std::map<int, ir::value*>& tiles::largest() {
|
||||
return largest_;
|
||||
}
|
||||
|
||||
|
||||
// Restricts x to [lo, hi]: first raises x to at least lo, then caps the
// result at hi. When hi < lo the cap wins and hi is returned.
unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
  const unsigned raised = (x < lo) ? lo : x;
  return (hi < raised) ? hi : raised;
}
|
||||
|
||||
|
||||
void tiles::init_hmma_tile(ir::value *i) {
|
||||
auto ord = order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
void tiles::init_hmma_tile(const layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.i->get_type()->get_tile_shapes();
|
||||
unsigned shape_0 = shapes[ord[0]];
|
||||
unsigned shape_1 = shapes[ord[1]];
|
||||
/* fragments per warp */
|
||||
@@ -127,7 +66,7 @@ void tiles::init_hmma_tile(ir::value *i) {
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
fpw_[axes_->get(i, d)] = fpw[d];
|
||||
fpw_[layout.axes[d]] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
@@ -141,149 +80,48 @@ void tiles::init_hmma_tile(ir::value *i) {
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
wpt_[axes_->get(i, d)] = wpt[d];
|
||||
wpt_[layout.axes[d]] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= wpt_[axes_->get(i, d)];
|
||||
effective_num_warps *= wpt_[layout.axes[d]];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void tiles::init_scanline_tile(ir::value *i) {
|
||||
auto ord = order(i);
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
unsigned size = i->get_type()->get_tile_num_elements();
|
||||
void tiles::init_scanline_tile(const layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.shapes;
|
||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
||||
unsigned ld = ord[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4);
|
||||
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]);
|
||||
current = current / mts_[layout.axes[ld]];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = ord[d];
|
||||
nts_[axes_->get(i, ld)] = 1;
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
nts_[layout.axes[ld]] = 1;
|
||||
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[layout.axes[ld]];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= mts_[axes_->get(i, d)];
|
||||
effective_num_threads *= mts_[layout.axes[d]];
|
||||
// std::cout << num_threads << " " << effective_num_threads << std::endl;
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
// Collects into `result` the ir::io_inst users of `v` whose pointer operand
// is `v` itself.
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
  for(ir::user* candidate: v->get_users()){
    auto *io = dynamic_cast<ir::io_inst*>(candidate);
    const bool uses_v_as_pointer = io && io->get_pointer_operand() == v;
    if(uses_v_as_pointer)
      result.insert(io);
  }
}
|
||||
|
||||
|
||||
// True when `v` is an ir::trans_inst, or when it is another instruction whose
// operands are all (recursively) is_trans. An instruction without operands
// gives true; a non-instruction value gives false.
// NOTE(review): no visited set guards the recursion; presumably operand
// chains are acyclic here -- confirm for phi nodes in loops.
bool tiles::is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v))
    return true;
  auto *inst = dynamic_cast<ir::instruction *>(v);
  if(!inst)
    return false;
  for(ir::value *operand: inst->ops())
    if(!is_trans(operand))
      return false;
  return true;
}
|
||||
|
||||
|
||||
void tiles::run(ir::module &) {
|
||||
hmma_.clear();
|
||||
largest_.clear();
|
||||
order_.clear();
|
||||
|
||||
size_t num_groups = layout_->num_layouts();
|
||||
// helpers
|
||||
auto rank = [](ir::value* v) {
|
||||
int ret = 0;
|
||||
for(int s: v->get_type()->get_tile_shapes())
|
||||
ret += s > 1;
|
||||
return ret;
|
||||
};
|
||||
|
||||
// find out which groups require hmma layout
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values_of(i);
|
||||
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||
if(hmma_c) hmma_[i] = HMMA_C;
|
||||
else hmma_[i] = SCANLINE;
|
||||
}
|
||||
|
||||
// find out which value is the largest in each group
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values_of(i);
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
|
||||
}
|
||||
|
||||
|
||||
// find out the layout ordering of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value* v: layout_->values_of(i))
|
||||
extract_io_use(v, io);
|
||||
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||
};
|
||||
auto it = std::max_element(io.begin(), io.end(), cmp);
|
||||
std::vector<int> order(rank(largest_[i]));
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
if(it != io.end()) {
|
||||
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b]; }
|
||||
);
|
||||
}
|
||||
order_[i] = order;
|
||||
}
|
||||
// matrix multiplication optimizations
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
for(ir::value* v: layout_->values_of(i))
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
dots.push_back(x);
|
||||
for(ir::dot_inst* dot: dots){
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
|
||||
auto a_val = layout_->values_of(layout_->layout_of(a));
|
||||
auto b_val = layout_->values_of(layout_->layout_of(b));
|
||||
for(ir::value *v: a_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
|
||||
for(ir::value *v: b_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
|
||||
}
|
||||
else{
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
|
||||
order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
|
||||
}
|
||||
}
|
||||
}
|
||||
// tiling parameters
|
||||
for(auto x: largest_){
|
||||
ir::value *i = x.second;
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
for(auto x: layout_->get_all()){
|
||||
/* HMMA parameters*/
|
||||
if(hmma_[x.first] == HMMA_C)
|
||||
init_hmma_tile(i);
|
||||
if(x.second.type == HMMA_884)
|
||||
init_hmma_tile(x.second);
|
||||
else
|
||||
init_scanline_tile(i);
|
||||
init_scanline_tile(x.second);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -577,37 +577,36 @@ inline int32_t ceil(int32_t num, int32_t div){
|
||||
return (num + div - 1)/div;
|
||||
}
|
||||
|
||||
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
auto order = tiles_->order(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
auto order = layout.order;
|
||||
const auto& shapes = layout.shapes;
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> block_size(dim);
|
||||
std::vector<unsigned> nts(dim);
|
||||
std::vector<unsigned> mts(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
contiguous[i] = tiles_->nts(v, i);
|
||||
block_size[i] = tiles_->mts(v, i);
|
||||
nts[i] = tiles_->nts(layout.i, i);
|
||||
mts[i] = tiles_->mts(layout.i, i);
|
||||
}
|
||||
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||
Value *contiguous_k = builder.getInt32(nts[k]);
|
||||
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
|
||||
unsigned per_block = contiguous[k] * block_size[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
||||
unsigned per_block = nts[k] * mts[k];
|
||||
unsigned per_thread = nts[k] * shapes[k] / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
unsigned offset = n / nts[k] * per_block + n % nts[k];
|
||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]};
|
||||
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
|
||||
}
|
||||
}
|
||||
|
||||
void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
// auto order = reorder_->get_order(v);
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = layout.shapes;
|
||||
if(shapes.size() > 3)
|
||||
throw std::runtime_error("unsupported");
|
||||
bool is_batched = shapes.size() >= 3;
|
||||
@@ -619,13 +618,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
||||
Value *_16 = builder.getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = tiles_->fpw(v, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(v, 1);
|
||||
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1;
|
||||
unsigned fpw_0 = tiles_->fpw(layout.i, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(layout.i, 1);
|
||||
unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = tiles_->wpt(v, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(v, 1);
|
||||
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1;
|
||||
unsigned wpt_0 = tiles_->wpt(layout.i, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(layout.i, 1);
|
||||
unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
@@ -706,18 +705,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
if(tiles_->hmma(v) == analysis::HMMA_C)
|
||||
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
|
||||
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
if(layout.type == analysis::HMMA_884)
|
||||
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
else
|
||||
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
|
||||
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
|
||||
}
|
||||
|
||||
/* -------------------
|
||||
@@ -727,7 +726,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
|
||||
if(tmap_.find(v) != tmap_.end())
|
||||
return;
|
||||
auto order = tiles_->order(v);
|
||||
auto order = layouts_->get(v).order;
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned pad = liveness_->get_pad(v);
|
||||
if(pad > 0)
|
||||
@@ -777,7 +776,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
axes[d].values = {builder.getInt32(0)};
|
||||
}
|
||||
}
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false);
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
|
||||
bool is_inserted = tmap_.insert({v, T}).second;
|
||||
// constant range
|
||||
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
|
||||
@@ -820,7 +819,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
|
||||
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
for(auto x: tiles_->largest())
|
||||
for(auto x: layouts_->get_all())
|
||||
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
@@ -868,7 +867,7 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F
|
||||
void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *scalars = tmap_.at(x->get_value_operand());
|
||||
// size_t ld = tiles_->order(x->get_pointer_operand())[0];
|
||||
// size_t ld = layouts_->order(x->get_pointer_operand())[0];
|
||||
// unsigned vector_size = 2;
|
||||
// // vectorize pointers
|
||||
// std::map<unsigned, Value*> ptr_packets;
|
||||
@@ -1015,9 +1014,9 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
|
||||
|
||||
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
unsigned vector_size = 1;
|
||||
auto x_order = tiles_->order(x);
|
||||
auto x_order = layouts_->get(x).order;
|
||||
ir::value *arg = x->get_operand(0);
|
||||
auto arg_order = tiles_->order(arg);
|
||||
auto arg_order = layouts_->get(arg).order;
|
||||
// tiles
|
||||
shared_tile* result = (shared_tile*)tmap_.at(x);
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
|
||||
@@ -1092,8 +1091,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
|
||||
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
|
||||
|
||||
auto ord_a = tiles_->order(dot->get_operand(0));
|
||||
auto ord_b = tiles_->order(dot->get_operand(1));
|
||||
auto ord_a = layouts_->get(dot->get_operand(0)).order;
|
||||
auto ord_b = layouts_->get(dot->get_operand(1)).order;
|
||||
|
||||
bool is_a_trans = is_trans(dot->get_operand(0));
|
||||
bool is_b_trans = is_trans(dot->get_operand(1));
|
||||
@@ -1255,7 +1254,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
||||
if(NK != 1) {
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(tiles_->hmma(dot) == analysis::HMMA_C)
|
||||
if(layouts_->get(dot).type == analysis::HMMA_884)
|
||||
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
||||
else
|
||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||
@@ -1271,7 +1270,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
||||
// find vector size
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = tiles_->order(ptr)[0];
|
||||
size_t ld = layouts_->get(ptr).order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
@@ -1343,7 +1342,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
size_t ld = tiles_->order(ptr)[0];
|
||||
size_t ld = layouts_->get(ptr).order[0];
|
||||
unsigned alignment = alignment_->get(ptr, ld);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
|
@@ -83,8 +83,7 @@ void coalesce::run(ir::module &mod) {
|
||||
if(axes.empty())
|
||||
continue;
|
||||
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
|
||||
remat.insert(remat.begin(),
|
||||
it->second.begin(), it->second.end());
|
||||
remat.insert(remat.begin(), it->second.begin(), it->second.end());
|
||||
}
|
||||
// rematerialize values
|
||||
for(ir::io_inst *r: remat) {
|
||||
|
@@ -202,9 +202,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes);
|
||||
codegen::analysis::layout layouts(&axes, &align);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::liveness liveness(&tiles);
|
||||
codegen::analysis::liveness liveness(&tiles, &layouts);
|
||||
codegen::analysis::allocation allocation(&liveness, &tiles);
|
||||
codegen::transform::membar barriers(&liveness, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
|
Reference in New Issue
Block a user