[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-09 15:05:44 -04:00
parent 10ab94d1c5
commit 9bc6df4fd1
10 changed files with 226 additions and 252 deletions

View File

@@ -19,6 +19,20 @@ namespace codegen{
namespace analysis{
class axes;
class align;
enum layout_type_t {
HMMA_884,
SCANLINE
};
struct layout_t {
layout_type_t type;
ir::value *i;
std::vector<int> axes;
std::vector<unsigned> shapes;
std::vector<int> order;
};
class layout {
typedef ir::value* node_t;
@@ -31,19 +45,24 @@ private:
public:
// constructor
layout(analysis::axes *axes);
layout(analysis::axes *axes, analysis::align *align);
// accessors
unsigned layout_of(ir::value *value) const;
const std::vector<ir::value*>& values_of(unsigned id) const;
size_t num_layouts() const;
layout_t get(ir::value *v) const;
const std::map<size_t, layout_t>& get_all() const;
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, layout_t> layouts_;
};
}

View File

@@ -22,6 +22,7 @@ namespace analysis{
typedef unsigned slot_index;
class tiles;
class layout;
struct segment {
slot_index start;
@@ -72,7 +73,7 @@ private:
public:
liveness(tiles *t): tiles_(t){ }
liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ }
// padding
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
// buffer size
@@ -92,6 +93,7 @@ public:
private:
// analysis
tiles *tiles_;
layout *layouts_;
// stuff
has_storage_map_t has_dedicated_storage_;
indices_map_t indices;

View File

@@ -5,6 +5,7 @@
#include <set>
#include <vector>
#include <memory>
#include "triton/codegen/analysis/layout.h"
namespace triton{
@@ -25,28 +26,22 @@ class axes;
class layout;
class align;
enum layout_t {
SCANLINE,
HMMA_C
};
class tiles {
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
private:
void init_hmma_tile(ir::value *i);
void init_scanline_tile(ir::value *i);
void init_hmma_tile(const layout_t& layout);
void init_scanline_tile(const layout_t& layout);
bool is_trans(ir::value *i);
public:
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
void run(ir::module &mod);
layout_t hmma(ir::value *value);
int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax);
int wpt(ir::value *value, unsigned ax);
std::vector<int> order(ir::value *v);
const std::map<int, ir::value*>& largest();
private:
// dependencies
@@ -56,9 +51,6 @@ private:
// number of warps
size_t num_warps_;
// tile properties
std::map<int, ir::value*> largest_;
std::map<int, std::vector<int>> order_;
std::map<int, layout_t> hmma_;
std::map<int, int> fpw_;
std::map<int, int> wpt_;
std::map<int, int> mts_;

View File

@@ -5,6 +5,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
@@ -171,9 +172,9 @@ private:
void create_shared_tile(ir::value *v, Builder &builder, Value *sh_mem_ptr);
void create_distributed_tile(ir::value *v, Builder &builder);
void create_tile(ir::value *v, Builder &builder, std::set<ir::value *> &seen, Value *sh_mem_ptr);
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_strided_scan_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_hmma_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_axes(const analysis::layout_t& layout, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction

View File

@@ -1,6 +1,8 @@
#include <algorithm>
#include <iostream>
#include <numeric>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
@@ -12,8 +14,8 @@ namespace analysis{
// constructor
layout::layout(analysis::axes *axes)
: axes_(axes) { }
layout::layout(analysis::axes *axes, analysis::align *align)
: axes_(axes), align_(align) { }
// get group id
unsigned layout::layout_of(ir::value *value) const
@@ -56,6 +58,51 @@ void layout::make_graph(ir::instruction *i) {
}
}
// hmma
bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_half_ty() &&
b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
layout_t layout::get(ir::value *v) const {
return layouts_.at(groups_.at(v));
}
const std::map<size_t, layout_t>& layout::get_all() const {
return layouts_;
}
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
inline bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
void layout::run(ir::module &mod) {
// make graph
graph_.clear();
@@ -64,6 +111,82 @@ void layout::run(ir::module &mod) {
});
// connected components
graph_.connected_components(&values_, &groups_);
// create layouts
for(const auto& x: values_) {
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE;
}
/* ---- TO CLEAN ---- */
size_t num_groups = num_layouts();
// helpers
auto rank = [this](ir::value* v) {
int ret = 0;
for(int s: v->get_type()->get_tile_shapes())
ret += s > 1;
return ret;
};
// find out which value is the largest in each group
for(const auto& x: values_) {
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
layouts_[x.first].axes = axes_->get(largest);
layouts_[x.first].i = largest;
layouts_[x.first].shapes = largest->get_type()->get_tile_shapes();
}
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: values_of(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
};
auto it = std::max_element(io.begin(), io.end(), cmp);
std::vector<int> order(layouts_[i].axes.size());
std::iota(order.begin(), order.end(), 0);
if(it != io.end()) {
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; }
);
}
layouts_[i].order = order;
}
// matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(get(dot).type == HMMA_884){
auto a_val = values_of(layout_of(a));
auto b_val = values_of(layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(a)].order = layouts_[layout_of(cts->get_operand(0))].order;
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
layouts_[layout_of(b)].order = layouts_[layout_of(cts->get_operand(0))].order;
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
layouts_[layout_of(a)].order = is_trans(a) ? row : col;
layouts_[layout_of(b)].order = is_trans(b) ? col : row;
}
}
}
}
}

View File

@@ -4,6 +4,7 @@
#include "triton/codegen/instructions.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
@@ -89,8 +90,8 @@ bool liveness::do_pad(ir::value *x) {
ir::value *b = dot->get_operand(1);
size_t a_previous = pad_[a];
size_t b_previous = pad_[b];
auto a_order = tiles_->order(a);
auto b_order = tiles_->order(b);
auto a_order = layouts_->get(a).order;
auto b_order = layouts_->get(b).order;
bool a_row = is_trans(a) ^ (a_order[0] == 1);
bool b_row = is_trans(b) ^ (b_order[0] == 1);
auto a_shapes = a->get_type()->get_tile_shapes();
@@ -108,9 +109,9 @@ bool liveness::do_pad(ir::value *x) {
}
// padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts);
auto cts_order = layouts_->get(cts).order;
ir::value *arg = cts->get_operand(0);
auto arg_order = tiles_->order(arg);
auto arg_order = layouts_->get(arg).order;
size_t previous = pad_[cts];
if(cts_order != arg_order)
pad_[cts] = std::max<int>(pad_[cts], 4);
@@ -144,7 +145,7 @@ unsigned liveness::num_bytes(ir::value *x) {
for(auto x: shapes)
num_elements *= x;
size_t depth;
if(tiles_->hmma(x))
if(layouts_->get(x).type == HMMA_884)
depth = tiles_->wpt(op, axis);
else
depth = tiles_->mts(op, axis);
@@ -153,7 +154,7 @@ unsigned liveness::num_bytes(ir::value *x) {
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = pad_.at(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
unsigned ld = x->get_type()->get_tile_shapes()[layouts_->get(x).order[0]];
num_bytes += pad * num_bytes / ld;
}
if(has_double(x))

View File

@@ -23,59 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
{ }
bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_half_ty() &&
b_ty->get_scalar_ty()->is_half_ty();
}
return result;
}
bool is_hmma_a_col(ir::value* v) {
for(ir::user *u: v->get_users())
if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(0)))
return true;
}
}
bool is_hmma_a_row(ir::value* v) {
for(ir::user *u: v->get_users())
if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(0)))
return true;
}
}
bool is_hmma_b_col(ir::value* v) {
for(ir::user *u: v->get_users())
if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(1)))
return true;
}
}
bool is_hmma_b_row(ir::value* v) {
for(ir::user *u: v->get_users())
if(is_hmma_c(u)){
ir::dot_inst* dot = (ir::dot_inst*)u;
if((v == dot->get_operand(1)))
return true;
}
}
layout_t tiles::hmma(ir::value *value) {
return hmma_.at(layout_->layout_of(value));
}
int tiles::mts(ir::value *value, unsigned ax) {
return mts_.at(axes_->get(value, ax));
@@ -93,24 +41,15 @@ int tiles::wpt(ir::value *value, unsigned ax) {
return wpt_.at(axes_->get(value, ax));
}
std::vector<int> tiles::order(ir::value *v) {
auto ret = order_[layout_->layout_of(v)];
return ret;
}
const std::map<int, ir::value*>& tiles::largest() {
return largest_;
}
unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
return std::min(std::max(x, lo), hi);
}
void tiles::init_hmma_tile(ir::value *i) {
auto ord = order(i);
auto shapes = i->get_type()->get_tile_shapes();
void tiles::init_hmma_tile(const layout_t& layout) {
auto ord = layout.order;
auto shapes = layout.i->get_type()->get_tile_shapes();
unsigned shape_0 = shapes[ord[0]];
unsigned shape_1 = shapes[ord[1]];
/* fragments per warp */
@@ -127,7 +66,7 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
fpw_[axes_->get(i, d)] = fpw[d];
fpw_[layout.axes[d]] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> wpt = {1, 1, 1};
@@ -141,149 +80,48 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
wpt_[axes_->get(i, d)] = wpt[d];
wpt_[layout.axes[d]] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt_[axes_->get(i, d)];
effective_num_warps *= wpt_[layout.axes[d]];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void tiles::init_scanline_tile(ir::value *i) {
auto ord = order(i);
auto shapes = i->get_type()->get_tile_shapes();
unsigned size = i->get_type()->get_tile_num_elements();
void tiles::init_scanline_tile(const layout_t& layout) {
auto ord = layout.order;
auto shapes = layout.shapes;
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
unsigned ld = ord[0];
unsigned num_threads = num_warps_*32;
unsigned current = num_threads;
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
current = current / mts_[axes_->get(i, ld)];
nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4);
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]);
current = current / mts_[layout.axes[ld]];
for(size_t d = 1; d < shapes.size(); d++){
ld = ord[d];
nts_[axes_->get(i, ld)] = 1;
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get(i, ld)];
nts_[layout.axes[ld]] = 1;
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]);
current = current / mts_[layout.axes[ld]];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts_[axes_->get(i, d)];
effective_num_threads *= mts_[layout.axes[d]];
// std::cout << num_threads << " " << effective_num_threads << std::endl;
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
bool tiles::is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
}
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
bool result = true;
for(ir::value *op: phi->ops())
result = result && is_trans(op);
return result;
}
return false;
}
void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
order_.clear();
size_t num_groups = layout_->num_layouts();
// helpers
auto rank = [](ir::value* v) {
int ret = 0;
for(int s: v->get_type()->get_tile_shapes())
ret += s > 1;
return ret;
};
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values_of(i);
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
if(hmma_c) hmma_[i] = HMMA_C;
else hmma_[i] = SCANLINE;
}
// find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values_of(i);
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values_of(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
};
auto it = std::max_element(io.begin(), io.end(), cmp);
std::vector<int> order(rank(largest_[i]));
std::iota(order.begin(), order.end(), 0);
if(it != io.end()) {
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b]; }
);
}
order_[i] = order;
}
// matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
auto a_val = layout_->values_of(layout_->layout_of(a));
auto b_val = layout_->values_of(layout_->layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
}
}
}
// tiling parameters
for(auto x: largest_){
ir::value *i = x.second;
if(!i->get_type()->is_tile_ty())
continue;
for(auto x: layout_->get_all()){
/* HMMA parameters*/
if(hmma_[x.first] == HMMA_C)
init_hmma_tile(i);
if(x.second.type == HMMA_884)
init_hmma_tile(x.second);
else
init_scanline_tile(i);
init_scanline_tile(x.second);
}
}

View File

@@ -577,37 +577,36 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = tiles_->order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = layout.order;
const auto& shapes = layout.shapes;
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> nts(dim);
std::vector<unsigned> mts(dim);
for(unsigned i = 0; i < shapes.size(); i++){
contiguous[i] = tiles_->nts(v, i);
block_size[i] = tiles_->mts(v, i);
nts[i] = tiles_->nts(layout.i, i);
mts[i] = tiles_->mts(layout.i, i);
}
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *contiguous_k = builder.getInt32(contiguous[k]);
Value *contiguous_k = builder.getInt32(nts[k]);
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = contiguous[k] * block_size[k];
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
unsigned per_block = nts[k] * mts[k];
unsigned per_thread = nts[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
unsigned offset = n / nts[k] * per_block + n % nts[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]};
axes_[layout.axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]};
}
}
void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
// auto order = reorder_->get_order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = layout.shapes;
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
@@ -619,13 +618,13 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
Value *_16 = builder.getInt32(16);
// fragments per warp
unsigned fpw_0 = tiles_->fpw(v, 0);
unsigned fpw_1 = tiles_->fpw(v, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(v, 2) : 1;
unsigned fpw_0 = tiles_->fpw(layout.i, 0);
unsigned fpw_1 = tiles_->fpw(layout.i, 1);
unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1;
// warps per tile
unsigned wpt_0 = tiles_->wpt(v, 0);
unsigned wpt_1 = tiles_->wpt(v, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(v, 2) : 1;
unsigned wpt_0 = tiles_->wpt(layout.i, 0);
unsigned wpt_1 = tiles_->wpt(layout.i, 1);
unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
@@ -706,18 +705,18 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
/* axes */
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[layout.axes[0]] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout.axes[1]] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
axes_[layout.axes[2]] = distributed_axis{1, idx_z, warp_id_2};
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(tiles_->hmma(v) == analysis::HMMA_C)
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
void selection::init_axes(const analysis::layout_t& layout, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(layout.type == analysis::HMMA_884)
init_hmma_axes(layout, builder, u_thread_id, u_warp_id);
else
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
init_strided_scan_axes(layout, builder, u_thread_id, u_warp_id);
}
/* -------------------
@@ -727,7 +726,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh_mem_ptr) {
if(tmap_.find(v) != tmap_.end())
return;
auto order = tiles_->order(v);
auto order = layouts_->get(v).order;
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = liveness_->get_pad(v);
if(pad > 0)
@@ -777,7 +776,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
axes[d].values = {builder.getInt32(0)};
}
}
distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, false);
distributed_tile *T = new distributed_tile(ty, shapes, layouts_->get(v).order, axes, builder, false);
bool is_inserted = tmap_.insert({v, T}).second;
// constant range
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
@@ -820,7 +819,7 @@ void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_m
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
for(auto x: tiles_->largest())
for(auto x: layouts_->get_all())
init_axes(x.second, builder, u_thread_warp_id, u_warp_id);
// create tile
std::set<ir::value*> seen;
@@ -868,7 +867,7 @@ void selection::lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, F
void selection::lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand());
// size_t ld = tiles_->order(x->get_pointer_operand())[0];
// size_t ld = layouts_->order(x->get_pointer_operand())[0];
// unsigned vector_size = 2;
// // vectorize pointers
// std::map<unsigned, Value*> ptr_packets;
@@ -1015,9 +1014,9 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio
void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
unsigned vector_size = 1;
auto x_order = tiles_->order(x);
auto x_order = layouts_->get(x).order;
ir::value *arg = x->get_operand(0);
auto arg_order = tiles_->order(arg);
auto arg_order = layouts_->get(arg).order;
// tiles
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
@@ -1092,8 +1091,8 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
Value* u_thread_id = tgt_->get_local_id(builder.GetInsertBlock()->getModule(), builder, 0);
auto ord_a = tiles_->order(dot->get_operand(0));
auto ord_b = tiles_->order(dot->get_operand(1));
auto ord_a = layouts_->get(dot->get_operand(0)).order;
auto ord_b = layouts_->get(dot->get_operand(1)).order;
bool is_a_trans = is_trans(dot->get_operand(0));
bool is_b_trans = is_trans(dot->get_operand(1));
@@ -1255,7 +1254,7 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(tiles_->hmma(dot) == analysis::HMMA_C)
if(layouts_->get(dot).type == analysis::HMMA_884)
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
@@ -1271,7 +1270,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
// find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0];
size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
@@ -1343,7 +1342,7 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = x->get_pointer_operand();
size_t ld = tiles_->order(ptr)[0];
size_t ld = layouts_->get(ptr).order[0];
unsigned alignment = alignment_->get(ptr, ld);
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);

View File

@@ -83,8 +83,7 @@ void coalesce::run(ir::module &mod) {
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++)
remat.insert(remat.begin(),
it->second.begin(), it->second.end());
remat.insert(remat.begin(), it->second.begin(), it->second.end());
}
// rematerialize values
for(ir::io_inst *r: remat) {

View File

@@ -202,9 +202,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// create passes
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::analysis::layout layouts(&axes, &align);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::liveness liveness(&tiles);
codegen::analysis::liveness liveness(&tiles, &layouts);
codegen::analysis::allocation allocation(&liveness, &tiles);
codegen::transform::membar barriers(&liveness, &allocation);
codegen::transform::dce dce;