[codegen] more cleaning
This commit is contained in:
@@ -24,8 +24,8 @@ class cts;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
allocation(liveness *live, tiles *params)
|
||||
: liveness_(live), tiles_(params){ }
|
||||
allocation(liveness *live)
|
||||
: liveness_(live) { }
|
||||
// accessors
|
||||
bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||
unsigned offset(ir::value *x) const { return offsets_.at(x); }
|
||||
@@ -38,7 +38,6 @@ private:
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
tiles *tiles_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -28,10 +28,13 @@ enum layout_type_t {
|
||||
|
||||
struct layout_t {
|
||||
layout_type_t type;
|
||||
ir::value *i;
|
||||
std::vector<int> axes;
|
||||
std::vector<unsigned> shapes;
|
||||
std::vector<int> order;
|
||||
std::map<int, int> mts;
|
||||
std::map<int, int> nts;
|
||||
std::map<int, int> fpw;
|
||||
std::map<int, int> wpt;
|
||||
};
|
||||
|
||||
class layout {
|
||||
@@ -43,15 +46,18 @@ private:
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
|
||||
void init_hmma_tile(layout_t& layout);
|
||||
void init_scanline_tile(layout_t &layout);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layout(analysis::axes *axes, analysis::align *align);
|
||||
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const;
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const;
|
||||
size_t num_layouts() const;
|
||||
layout_t get(ir::value *v) const;
|
||||
const std::map<size_t, layout_t>& get_all() const;
|
||||
const layout_t& get(ir::value *v) const;
|
||||
std::map<size_t, layout_t> &get_all();
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
@@ -59,6 +65,7 @@ public:
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
|
@@ -73,7 +73,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
liveness(tiles *t, layout *l): tiles_(t), layouts_(l){ }
|
||||
liveness(layout *l): layouts_(l){ }
|
||||
// padding
|
||||
unsigned get_pad(ir::value *v) const { return pad_.at(v); }
|
||||
// buffer size
|
||||
@@ -92,7 +92,6 @@ public:
|
||||
|
||||
private:
|
||||
// analysis
|
||||
tiles *tiles_;
|
||||
layout *layouts_;
|
||||
// stuff
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
|
@@ -1,65 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_TILES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_TILES_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
class function;
|
||||
class metaparameter;
|
||||
class constant_int;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class layout;
|
||||
class align;
|
||||
|
||||
|
||||
class tiles {
|
||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||
private:
|
||||
void init_hmma_tile(const layout_t& layout);
|
||||
void init_scanline_tile(const layout_t& layout);
|
||||
bool is_trans(ir::value *i);
|
||||
|
||||
public:
|
||||
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||
void run(ir::module &mod);
|
||||
int mts(ir::value *value, unsigned ax);
|
||||
int nts(ir::value *value, unsigned ax);
|
||||
int fpw(ir::value *value, unsigned ax);
|
||||
int wpt(ir::value *value, unsigned ax);
|
||||
|
||||
|
||||
private:
|
||||
// dependencies
|
||||
analysis::align* align_;
|
||||
analysis::layout* layout_;
|
||||
analysis::axes* axes_;
|
||||
// number of warps
|
||||
size_t num_warps_;
|
||||
// tile properties
|
||||
std::map<int, int> fpw_;
|
||||
std::map<int, int> wpt_;
|
||||
std::map<int, int> mts_;
|
||||
std::map<int, int> nts_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -210,10 +210,10 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc,
|
||||
analysis::align *alignment, analysis::axes *axes,
|
||||
analysis::layout *layouts, target *tgt, unsigned num_warps)
|
||||
: liveness_(liveness), alloc_(alloc), tiles_(tiles),
|
||||
: liveness_(liveness), alloc_(alloc),
|
||||
alignment_(alignment), a_axes_(axes), layouts_(layouts),
|
||||
tgt_(tgt), num_warps_(num_warps){ }
|
||||
|
||||
@@ -224,7 +224,6 @@ private:
|
||||
tmap_t tmap_;
|
||||
analysis::liveness *liveness_;
|
||||
analysis::allocation *alloc_;
|
||||
analysis::tiles *tiles_;
|
||||
analysis::axes *a_axes_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
|
@@ -11,15 +11,6 @@
|
||||
// codegen
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
|
||||
|
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
double total_time = 0;
|
||||
op();
|
||||
stream->synchronize();
|
||||
while(total_time*1e-9 < 1e-2){
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
|
@@ -3,7 +3,6 @@
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
|
@@ -14,8 +14,8 @@ namespace analysis{
|
||||
|
||||
|
||||
// constructor
|
||||
layout::layout(analysis::axes *axes, analysis::align *align)
|
||||
: axes_(axes), align_(align) { }
|
||||
layout::layout(analysis::axes *axes, analysis::align *align, size_t num_warps)
|
||||
: axes_(axes), align_(align), num_warps_(num_warps) { }
|
||||
|
||||
// get group id
|
||||
unsigned layout::layout_of(ir::value *value) const
|
||||
@@ -72,19 +72,19 @@ bool is_hmma_c(ir::value *v){
|
||||
return result;
|
||||
}
|
||||
|
||||
layout_t layout::get(ir::value *v) const {
|
||||
const layout_t &layout::get(ir::value *v) const {
|
||||
return layouts_.at(groups_.at(v));
|
||||
}
|
||||
|
||||
const std::map<size_t, layout_t>& layout::get_all() const {
|
||||
std::map<size_t, layout_t>& layout::get_all() {
|
||||
return layouts_;
|
||||
}
|
||||
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(i);
|
||||
result.insert(v);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +102,75 @@ inline bool is_trans(ir::value *v) {
|
||||
return false;
|
||||
}
|
||||
|
||||
inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
return std::min(std::max(x, lo), hi);
|
||||
}
|
||||
|
||||
void layout::init_hmma_tile(layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.shapes;
|
||||
unsigned shape_0 = shapes[ord[0]];
|
||||
unsigned shape_1 = shapes[ord[1]];
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> fpw = {1, 1, 1};
|
||||
std::vector<unsigned> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
layout.fpw[d] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
std::vector<unsigned> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
layout.wpt[d] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= layout.wpt[d];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void layout::init_scanline_tile(layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.shapes;
|
||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
||||
unsigned ld = ord[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
layout.nts[ld] = clamp(size / num_threads, 1, 4);
|
||||
layout.mts[ld] = clamp(current, 1, shapes[ld] / layout.nts[ld]);
|
||||
current = current / layout.mts[ld];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = ord[d];
|
||||
layout.nts[ld] = 1;
|
||||
layout.mts[ld] = clamp(current, 1, shapes[ld]);
|
||||
current = current / layout.mts[ld];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= layout.mts[d];
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void layout::run(ir::module &mod) {
|
||||
// make graph
|
||||
@@ -114,8 +183,8 @@ void layout::run(ir::module &mod) {
|
||||
// create layouts
|
||||
for(const auto& x: values_) {
|
||||
bool hmma_c = std::any_of(x.second.begin(), x.second.end(), &is_hmma_c);
|
||||
// type
|
||||
layouts_[x.first].type = hmma_c ? HMMA_884 : SCANLINE;
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -130,35 +199,32 @@ void layout::run(ir::module &mod) {
|
||||
return ret;
|
||||
};
|
||||
|
||||
// find out which value is the largest in each group
|
||||
// find out axes for each layout
|
||||
for(const auto& x: values_) {
|
||||
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
|
||||
ir::value *largest = *std::max_element(x.second.begin(), x.second.end(), cmp);
|
||||
layouts_[x.first].axes = axes_->get(largest);
|
||||
layouts_[x.first].i = largest;
|
||||
layouts_[x.first].shapes = largest->get_type()->get_tile_shapes();
|
||||
}
|
||||
|
||||
|
||||
// find out the layout ordering of a group
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::set<ir::io_inst*> io;
|
||||
for(ir::value* v: values_of(i))
|
||||
extract_io_use(v, io);
|
||||
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
|
||||
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
|
||||
};
|
||||
auto it = std::max_element(io.begin(), io.end(), cmp);
|
||||
std::vector<int> order(layouts_[i].axes.size());
|
||||
for(const auto& x: values_) {
|
||||
std::set<ir::value*> ptr;
|
||||
for(ir::value* v: x.second)
|
||||
extract_io_use(v, ptr);
|
||||
size_t rank = layouts_[x.first].axes.size();
|
||||
std::vector<int> order(rank);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
if(it != io.end()) {
|
||||
auto max_contiguous = align_->contiguous((*it)->get_pointer_operand());
|
||||
for(ir::value *v: ptr){
|
||||
auto max_contiguous = align_->contiguous(v);
|
||||
std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b]; }
|
||||
);
|
||||
}
|
||||
layouts_[i].order = order;
|
||||
layouts_[x.first].order = order;
|
||||
}
|
||||
|
||||
// matrix multiplication optimizations
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
@@ -187,6 +253,14 @@ void layout::run(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
|
||||
// tiling parameters
|
||||
for(auto& x: layouts_){
|
||||
/* HMMA parameters*/
|
||||
if(x.second.type == HMMA_884)
|
||||
init_hmma_tile(x.second);
|
||||
else
|
||||
init_scanline_tile(x.second);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -3,7 +3,6 @@
|
||||
#include <unordered_set>
|
||||
#include "triton/codegen/instructions.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -146,9 +145,9 @@ unsigned liveness::num_bytes(ir::value *x) {
|
||||
num_elements *= x;
|
||||
size_t depth;
|
||||
if(layouts_->get(x).type == HMMA_884)
|
||||
depth = tiles_->wpt(op, axis);
|
||||
depth = layouts_->get(op).wpt.at(axis);
|
||||
else
|
||||
depth = tiles_->mts(op, axis);
|
||||
depth = layouts_->get(op).mts.at(axis);
|
||||
return num_elements * num_bytes * depth;
|
||||
}
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
|
@@ -1,130 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <numeric>
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, analysis::layout *layout):
|
||||
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
||||
{ }
|
||||
|
||||
|
||||
|
||||
int tiles::mts(ir::value *value, unsigned ax) {
|
||||
return mts_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::nts(ir::value *value, unsigned ax) {
|
||||
return nts_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::fpw(ir::value *value, unsigned ax) {
|
||||
return fpw_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
int tiles::wpt(ir::value *value, unsigned ax) {
|
||||
return wpt_.at(axes_->get(value, ax));
|
||||
}
|
||||
|
||||
|
||||
unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
|
||||
return std::min(std::max(x, lo), hi);
|
||||
}
|
||||
|
||||
|
||||
void tiles::init_hmma_tile(const layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.i->get_type()->get_tile_shapes();
|
||||
unsigned shape_0 = shapes[ord[0]];
|
||||
unsigned shape_1 = shapes[ord[1]];
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> fpw = {1, 1, 1};
|
||||
std::vector<unsigned> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_0/8)*(shape_1/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw;
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[0] = clamp(fpw[0]*2, 1, shape_0 / 8);
|
||||
if(fpw[0]*fpw[1] < num_fragments)
|
||||
fpw[1] = clamp(fpw[1]*2, 1, shape_1 / 8);
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
fpw_[layout.axes[d]] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
std::vector<unsigned> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
wpt_[layout.axes[d]] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= wpt_[layout.axes[d]];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void tiles::init_scanline_tile(const layout_t& layout) {
|
||||
auto ord = layout.order;
|
||||
auto shapes = layout.shapes;
|
||||
unsigned size = std::accumulate(shapes.begin(), shapes.end(), 1, std::multiplies<int>());
|
||||
unsigned ld = ord[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
nts_[layout.axes[ld]] = clamp(size / num_threads, 1, 4);
|
||||
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld] / nts_[layout.axes[ld]]);
|
||||
current = current / mts_[layout.axes[ld]];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = ord[d];
|
||||
nts_[layout.axes[ld]] = 1;
|
||||
mts_[layout.axes[ld]] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[layout.axes[ld]];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= mts_[layout.axes[d]];
|
||||
// std::cout << num_threads << " " << effective_num_threads << std::endl;
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
void tiles::run(ir::module &) {
|
||||
// tiling parameters
|
||||
for(auto x: layout_->get_all()){
|
||||
/* HMMA parameters*/
|
||||
if(x.second.type == HMMA_884)
|
||||
init_hmma_tile(x.second);
|
||||
else
|
||||
init_scanline_tile(x.second);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -4,7 +4,6 @@
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
@@ -584,8 +583,8 @@ void selection::init_strided_scan_axes(const analysis::layout_t& layout, IRBuild
|
||||
std::vector<unsigned> nts(dim);
|
||||
std::vector<unsigned> mts(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
nts[i] = tiles_->nts(layout.i, i);
|
||||
mts[i] = tiles_->mts(layout.i, i);
|
||||
nts[i] = layout.nts.at(i);
|
||||
mts[i] = layout.mts.at(i);
|
||||
}
|
||||
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
|
||||
std::vector<Value*> thread_id = delinearize(full_thread_id, order, mts, builder);
|
||||
@@ -618,13 +617,13 @@ void selection::init_hmma_axes(const analysis::layout_t& layout, IRBuilder<> &bu
|
||||
Value *_16 = builder.getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = tiles_->fpw(layout.i, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(layout.i, 1);
|
||||
unsigned fpw_2 = is_batched ? tiles_->fpw(layout.i, 2) : 1;
|
||||
unsigned fpw_0 = layout.fpw.at(0);
|
||||
unsigned fpw_1 = layout.fpw.at(1);
|
||||
unsigned fpw_2 = is_batched ? layout.fpw.at(2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = tiles_->wpt(layout.i, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(layout.i, 1);
|
||||
unsigned wpt_2 = is_batched ? tiles_->wpt(layout.i, 2) : 1;
|
||||
unsigned wpt_0 = layout.wpt.at(0);
|
||||
unsigned wpt_1 = layout.wpt.at(1);
|
||||
unsigned wpt_2 = is_batched ? layout.wpt.at(2) : 1;
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
@@ -933,7 +932,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = tiles_->wpt(op, axis);
|
||||
unsigned depth = layouts_->get(op).wpt.at(axis);
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
@@ -1022,7 +1021,7 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(arg);
|
||||
if(x_order == arg_order){
|
||||
size_t ld = arg_order[0];
|
||||
vector_size = std::min(tiles_->nts(x, ld),tiles_->nts(arg, ld));
|
||||
vector_size = std::min(layouts_->get(x).nts.at(ld), layouts_->get(arg).nts.at(ld));
|
||||
}
|
||||
|
||||
std::map<unsigned, Value*> packets;
|
||||
@@ -1118,12 +1117,12 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn
|
||||
"{$10, $11}, "
|
||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||
|
||||
unsigned fpw_0 = tiles_->fpw(dot, 0);
|
||||
unsigned fpw_1 = tiles_->fpw(dot, 1);
|
||||
unsigned fpw_0 = layouts_->get(dot).fpw.at(0);
|
||||
unsigned fpw_1 = layouts_->get(dot).fpw.at(1);
|
||||
unsigned wts_0 = fpw_0 * 8;
|
||||
unsigned wts_1 = fpw_1 * 8;
|
||||
unsigned wpt_0 = tiles_->wpt(dot, 0);
|
||||
unsigned wpt_1 = tiles_->wpt(dot, 1);
|
||||
unsigned wpt_0 = layouts_->get(dot).wpt.at(0);
|
||||
unsigned wpt_1 = layouts_->get(dot).wpt.at(1);
|
||||
unsigned stride_rep_i = wpt_0 * wts_0;
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0] / stride_rep_i;
|
||||
|
@@ -4,11 +4,17 @@
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -202,17 +208,16 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// create passes
|
||||
codegen::analysis::align align;
|
||||
codegen::analysis::axes axes;
|
||||
codegen::analysis::layout layouts(&axes, &align);
|
||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||
codegen::analysis::liveness liveness(&tiles, &layouts);
|
||||
codegen::analysis::allocation allocation(&liveness, &tiles);
|
||||
codegen::analysis::layout layouts(&axes, &align, opt.num_warps);
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &allocation);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&align);
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::cts cts;
|
||||
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, target.get(), opt.num_warps);
|
||||
codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
@@ -226,24 +231,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
reassociate.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
dce.run(module);
|
||||
cts.run(module);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
tiles.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
align.run(module);
|
||||
tiles.run(module);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
Reference in New Issue
Block a user