[codegen] more progress towards unified dot implementation

This commit is contained in:
Philippe Tillet
2019-09-26 14:01:28 -04:00
parent 69800a0318
commit 575dd06be3
12 changed files with 227 additions and 243 deletions

View File

@@ -4,6 +4,7 @@
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
@@ -25,10 +26,8 @@ class allocation {
public:
allocation(liveness *live, tiles *params)
: liveness_(live), tiles_(params){ }
// utilities
unsigned num_bytes(ir::value *x);
unsigned is_ld_padded(ir::value* x);
// accessors
bool has_offset(ir::value *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(ir::value *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
@@ -36,7 +35,6 @@ public:
private:
std::map<ir::value*, unsigned> offsets_;
std::map<ir::value*, unsigned> num_bytes_;
size_t allocated_size_;
// dependences
liveness *liveness_;

View File

@@ -2,6 +2,8 @@
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
namespace triton{
@@ -10,6 +12,7 @@ namespace ir{
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
@@ -17,7 +20,7 @@ namespace analysis{
typedef unsigned slot_index;
class cts;
class tiles;
struct segment {
slot_index start;
@@ -37,21 +40,47 @@ struct double_buffer_info_t {
ir::phi_node* phi;
};
// Buffer descriptor: a group identifier paired with a size.
struct buffer_t {
  // Identifier of the buffer; the only field that takes part in ordering.
  unsigned id;
  // Size recorded for the buffer.
  size_t size;

  // Orders buffers by `id` alone. Two descriptors that share an id are
  // therefore equivalent as ordered-container keys, whatever their `size`.
  bool operator<(buffer_t rhs) const {
    const bool precedes = id < rhs.id;
    return precedes;
  }
};
class liveness {
private:
typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<ir::value*, segment> intervals_map_t;
typedef std::map<buffer_t, segment> intervals_map_t;
typedef std::map<ir::value*, bool> has_storage_map_t;
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
public:
// Intervals iterators
using iterator = intervals_map_t::iterator;
using const_iterator = intervals_map_t::const_iterator;
private:
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
void extract_double_bufferable(ir::instruction *i);
void extract_buffers(ir::instruction *i);
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
void make_graph(ir::instruction *i);
public:
liveness(tiles *t): tiles_(t){ }
// buffer size
unsigned is_ld_padded(ir::value *x);
unsigned num_bytes(ir::value *x);
// accessors
const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(ir::value* v) const { return intervals_.at(v); }
const intervals_map_t& intervals() const { return intervals_; }
segment get_interval(buffer_t v) const { return intervals_.at(v); }
// buffers
buffer_t get_buffer(ir::value *v) const { return groups_.at(v); }
std::vector<ir::value*> get_values(buffer_t x) const { return values_.at(x); }
// double-buffering
bool has_double(ir::value *x) const { return double_.find(x) != double_.end(); }
double_buffer_info_t get_double(ir::value *x) const { return double_.at(x); }
@@ -59,10 +88,19 @@ public:
void run(ir::module &mod);
private:
// analysis
tiles *tiles_;
// stuff
has_storage_map_t has_dedicated_storage_;
indices_map_t indices_;
indices_map_t indices;
intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, std::vector<ir::value*>> parents_;
// graph
std::set<node_t> nodes_;
graph_t graph_;
std::map<ir::value*, buffer_t> groups_;
std::map<buffer_t, std::vector<ir::value*>> values_;
};
}

View File

@@ -56,7 +56,7 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
{ ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}},
{ ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}},
// array arithmetic
{ ir::INST_TRANS, {SHARED, {DISTRIBUTED}}}, // TODO: not necessarily
{ ir::INST_TRANS, {SHARED, {SHARED}}},
{ ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}},
{ ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
// terminator

View File

@@ -15,79 +15,28 @@ namespace triton{
namespace codegen{
namespace analysis{
// Returns the padding value associated with `x`; 0 when no rule below matches.
// Rules are tried in order: transposition, HMMA operand kind, then phi node
// (largest padding among its incoming values).
unsigned allocation::is_ld_padded(ir::value *x) {
  // A transposition whose first permutation entry is non-zero is padded by 4.
  if(auto *transposed = dynamic_cast<ir::trans_inst*>(x)){
    if(transposed->get_perm()[0]->get_value() != 0)
      return 4;
  }
  auto order = tiles_->order(x);
  const bool col_major = (order[0] == 0);
  // HMMA operands: the padding depends on the operand kind and on whether
  // the first entry of the order is axis 0.
  if(tiles_->hmma(x) == HMMA_A_ROW){
    if(col_major)
      return 16;
    return 8;
  }
  if(tiles_->hmma(x) == HMMA_A_COL){
    if(col_major)
      return 8;
    return 16;
  }
  if(tiles_->hmma(x) == HMMA_B_COL){
    if(col_major)
      return 16;
    return 8;
  }
  if(tiles_->hmma(x) == HMMA_B_ROW){
    if(col_major)
      return 8;
    return 16;
  }
  // A phi node inherits the largest padding of its incoming values.
  if(auto *phi = dynamic_cast<ir::phi_node*>(x)){
    unsigned widest = 0;
    for(unsigned idx = 0; idx < phi->get_num_incoming(); ++idx){
      const unsigned candidate = is_ld_padded(phi->get_incoming_value(idx));
      if(widest < candidate)
        widest = candidate;
    }
    return widest;
  }
  return 0;
}
// Returns the number of bytes accounted for `x`.
//  - Reductions: scalar element size (bytes) * product of the operand's tile
//    shapes with the reduced axis removed * a depth taken from tiles_->wpt()
//    when tiles_->hmma(x) is non-zero, and from tiles_->mts() otherwise.
//  - Anything else: primitive size of the type (bytes), grown by the padding
//    reported by is_ld_padded(), then doubled when liveness reports `x` as
//    double-buffered.
unsigned allocation::num_bytes(ir::value *x) {
  auto *reduction = dynamic_cast<ir::reduce_inst*>(x);
  if(!reduction){
    unsigned total = x->get_type()->get_primitive_size_in_bits() / 8;
    const unsigned padding = is_ld_padded(x);
    if(padding > 0){
      // extent of the axis that comes first in the order of `x`
      const unsigned leading = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
      total += padding * total / leading;
    }
    if(liveness_->has_double(x))
      total *= 2;
    return total;
  }
  // reduction
  const unsigned elem_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
  const size_t axis = reduction->get_axis();
  ir::value *operand = reduction->get_operand(0);
  auto kept_shapes = operand->get_type()->get_tile_shapes();
  kept_shapes.erase(kept_shapes.begin() + axis);
  size_t num_elements = 1;
  for(auto extent: kept_shapes)
    num_elements *= extent;
  size_t depth;
  if(tiles_->hmma(x))
    depth = tiles_->wpt(operand, axis);
  else
    depth = tiles_->mts(operand, axis);
  return num_elements * elem_bytes * depth;
}
void allocation::run(ir::module &mod) {
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
std::vector<ir::value *> I;
std::vector<buffer_t> I;
for(auto x: liveness_->intervals())
I.push_back(x.first);
std::vector<ir::value *> J = I;
std::vector<buffer_t> J = I;
triples_map_type H;
H.insert({0, segment{0, INT_MAX}});
std::vector<ir::value *> V;
std::map<ir::value *, unsigned> starts;
std::vector<buffer_t> V;
std::map<buffer_t, unsigned> starts;
while(!J.empty()){
auto h_it = H.begin();
unsigned w = h_it->first;
segment xh = h_it->second;
H.erase(h_it);
auto j_it = std::find_if(J.begin(), J.end(), [&](ir::value *JJ){
auto j_it = std::find_if(J.begin(), J.end(), [&](buffer_t JJ){
segment xj = liveness_->get_interval(JJ);
bool res = xj.intersect(xh);
for(auto val: H)
@@ -95,7 +44,7 @@ void allocation::run(ir::module &mod) {
return res;
});
if(j_it != J.end()){
unsigned size = num_bytes(*j_it);
unsigned size = j_it->size;
segment xj = liveness_->get_interval(*j_it);
starts[*j_it] = w;
H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}});
@@ -109,14 +58,14 @@ void allocation::run(ir::module &mod) {
}
// Build interference graph
std::map<ir::value*, std::set<ir::value *>> interferences;
for(ir::value *x: V)
for(ir::value *y: V){
if(x == y)
std::map<buffer_t, std::set<buffer_t>> interferences;
for(buffer_t x: V)
for(buffer_t y: V){
if(x.id == y.id)
continue;
unsigned X0 = starts[x], Y0 = starts[y];
unsigned NX = num_bytes(x);
unsigned NY = num_bytes(y);
unsigned NX = x.size;
unsigned NY = y.size;
segment XS = {X0, X0 + NX};
segment YS = {Y0, Y0 + NY};
if(liveness_->get_interval(x).intersect(liveness_->get_interval(y))
@@ -125,17 +74,17 @@ void allocation::run(ir::module &mod) {
}
// Initialize colors
std::map<ir::value *, int> colors;
for(ir::value *X: V)
colors[X] = (X==V[0])?0:-1;
std::map<buffer_t, int> colors;
for(buffer_t X: V)
colors[X] = (X.id==V[0].id)?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(ir::value *x: V){
for(buffer_t x: V){
// Non-neighboring colors are available
std::fill(available.begin(), available.end(), true);
for(ir::value *Y: interferences[x]){
for(buffer_t Y: interferences[x]){
int color = colors[Y];
if(color >= 0)
available[color] = false;
@@ -146,21 +95,24 @@ void allocation::run(ir::module &mod) {
}
// Finalize allocation
for(ir::value *x: V){
for(buffer_t x: V){
unsigned Adj = 0;
for(ir::value *y: interferences[x])
Adj = std::max(Adj, starts[y] + num_bytes(y));
offsets_[x] = starts[x] + colors[x] * Adj;
if(liveness_->has_double(x)){
auto info = liveness_->get_double(x);
offsets_[info.latch] = offsets_[x] + num_bytes(x) / 2;
for(buffer_t y: interferences[x])
Adj = std::max<unsigned>(Adj, starts[y] + y.size);
// create offsets
for(ir::value *v: liveness_->get_values(x)){
offsets_[v] = starts[x] + colors[x] * Adj;
if(liveness_->has_double(v)){
auto info = liveness_->get_double(v);
offsets_[info.latch] = offsets_[v] + x.size / 2;
}
}
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(auto &x: offsets_){
allocated_size_ = std::max<size_t>(allocated_size_, x.second + num_bytes(x.first));
allocated_size_ = std::max<size_t>(allocated_size_, x.second + liveness_->get_buffer(x.first).size);
}
}

View File

@@ -1,6 +1,9 @@
#include <iostream>
#include <climits>
#include <unordered_set>
#include "triton/codegen/instructions.h"
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
@@ -25,7 +28,7 @@ inline bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
throw std::runtime_error("unreachable");
}
inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, double_buffer_info_t>& result) {
void liveness::extract_double_bufferable(ir::instruction *i) {
auto* phi = dynamic_cast<ir::phi_node*>(i);
if(!phi || phi->get_num_incoming() != 2)
return;
@@ -42,65 +45,142 @@ inline void extract_double_bufferable(ir::instruction *i, std::map<ir::value*, d
if(!i_0 || !i_1 || storage_info.at(i_0->get_id()).first != SHARED || storage_info.at(i_1->get_id()).first != SHARED)
return;
if(is_latch_1)
result[value_0] = double_buffer_info_t{value_1, phi};
double_[value_0] = double_buffer_info_t{value_1, phi};
if(is_latch_0)
result[value_1] = double_buffer_info_t{value_0, phi};
double_[value_1] = double_buffer_info_t{value_0, phi};
}
// Adds to the buffer graph the edges contributed by instruction `i`.
// Two kinds of undirected edges are recorded:
//  - between a value that has double-buffering info and its latch;
//  - between a transposition (INST_TRANS) and its first operand.
void liveness::make_graph(ir::instruction *i) {
  // Registers both endpoints as nodes and connects them in both directions.
  auto link = [this](node_t a, node_t b){
    nodes_.insert(a);
    nodes_.insert(b);
    graph_[a].insert(b);
    graph_[b].insert(a);
  };
  if(has_double(i))
    link(i, double_[i].latch);
  if(i->get_id() == ir::INST_TRANS)
    link(i, i->get_operand(0));
}
// Connected components.
// Depth-first walk over `graph` starting at `x`: every node reached that is
// still present in `nodes` is removed from it and assigned to the buffer
// identified by `group_id` (recorded in groups_ and values_).
//
// x        : node to visit
// nodes    : nodes not yet assigned to any group; visited nodes are erased
// graph    : adjacency sets (operator[] is used, hence the non-const reference)
// group_id : identifier given to the buffer of this component
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
  // A node absent from `nodes` was already handled. Bail out before touching
  // groups_/values_: registering first (as was done previously) appended `x`
  // to values_ once more for every edge leading back to it.
  auto pending = nodes.find(x);
  if(pending == nodes.end())
    return;
  nodes.erase(pending);
  // NOTE(review): the size stored in each descriptor is the one of the node
  // being visited, while buffer_t is ordered by id only, so the key kept in
  // values_ carries the size of the first node of the group -- confirm this
  // is intended when members of a group have different sizes.
  buffer_t buffer{group_id, num_bytes(x)};
  groups_[x] = buffer;
  values_[buffer].push_back(x);
  for(const node_t &y: graph[x])
    connected_components(y, nodes, graph, group_id);
}
// Returns the padding value associated with `x`; 0 when no rule below matches.
// Rules are tried in order: transposition, HMMA operand kind, then phi node
// (largest padding among its incoming values).
unsigned liveness::is_ld_padded(ir::value *x) {
  // A transposition whose first permutation entry is non-zero is padded by 4.
  if(auto *transposed = dynamic_cast<ir::trans_inst*>(x)){
    const bool first_axis_moved = transposed->get_perm()[0]->get_value() != 0;
    if(first_axis_moved)
      return 4;
  }
  // The order is still queried here, but the HMMA paddings below do not
  // depend on it: each operand kind has a single padding value.
  auto order = tiles_->order(x);
  const bool is_col_major = order[0] == 0;
  (void)is_col_major;
  if(tiles_->hmma(x) == HMMA_A_ROW)
    return 16;
  if(tiles_->hmma(x) == HMMA_A_COL)
    return 8;
  if(tiles_->hmma(x) == HMMA_B_COL)
    return 16;
  if(tiles_->hmma(x) == HMMA_B_ROW)
    return 8;
  // A phi node inherits the largest padding of its incoming values.
  if(auto *phi = dynamic_cast<ir::phi_node*>(x)){
    unsigned widest = 0;
    for(unsigned idx = 0; idx < phi->get_num_incoming(); ++idx){
      const unsigned candidate = is_ld_padded(phi->get_incoming_value(idx));
      if(widest < candidate)
        widest = candidate;
    }
    return widest;
  }
  return 0;
}
// Returns the number of bytes accounted for `x`.
//  - Reductions: scalar element size (bytes) * product of the operand's tile
//    shapes with the reduced axis removed * a depth taken from tiles_->wpt()
//    when tiles_->hmma(x) is non-zero, and from tiles_->mts() otherwise.
//  - Anything else: primitive size of the type (bytes), grown by the padding
//    reported by is_ld_padded(), then doubled when `x` has double-buffering
//    info.
unsigned liveness::num_bytes(ir::value *x) {
  auto *reduction = dynamic_cast<ir::reduce_inst*>(x);
  if(reduction){
    const unsigned elem_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    const size_t axis = reduction->get_axis();
    ir::value *operand = reduction->get_operand(0);
    // tile shapes of the operand without the reduced axis
    auto kept_shapes = operand->get_type()->get_tile_shapes();
    kept_shapes.erase(kept_shapes.begin() + axis);
    size_t num_elements = 1;
    for(auto extent: kept_shapes)
      num_elements *= extent;
    size_t depth;
    if(tiles_->hmma(x))
      depth = tiles_->wpt(operand, axis);
    else
      depth = tiles_->mts(operand, axis);
    return num_elements * elem_bytes * depth;
  }
  // generic value
  unsigned total = x->get_type()->get_primitive_size_in_bits() / 8;
  const unsigned padding = is_ld_padded(x);
  if(padding > 0){
    // extent of the axis that comes first in the order of `x`
    const unsigned leading = x->get_type()->get_tile_shapes()[tiles_->order(x)[0]];
    total += padding * total / leading;
  }
  if(has_double(x))
    total *= 2;
  return total;
}
// Entry point
void liveness::run(ir::module &mod) {
double_.clear();
indices_.clear();
indices.clear();
intervals_.clear();
parents_.clear();
// set of pair of values that can be double-buffered
// Create set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) {
extract_double_bufferable(i, this->double_);
this->extract_double_bufferable(i);
});
// Create buffer dependency graph
ir::for_each_instruction(mod, [this](ir::instruction* i) {
this->make_graph(i);
});
// connected components
unsigned group_id = 0;
while(!nodes_.empty()){
connected_components(*nodes_.begin(), nodes_, graph_, group_id++);
}
// Assigns index to each instruction
for(ir::function *fn: mod.get_function_list()){
// Assigns index to each instruction
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices_.insert({instr, index});
}
// Liveness analysis
// Creates live intervals
for(auto i: indices_){
ir::value *v = i.first;
ir::instruction* instr = dynamic_cast<ir::instruction*>(v);
if(!instr)
continue;
if(storage_info.at(instr->get_id()).first != SHARED)
continue;
unsigned start = i.second;
unsigned end = start;
for(ir::value *u: v->get_users()){
start = std::min(start, indices_.at(u));
end = std::max(end, indices_.at(u));
}
intervals_[v] = segment{start, end};
}
// Double-Buffering
// Arrays are live throughout the end of the loop
auto it = intervals_.begin();
while(it != intervals_.end()) {
ir::value *x = it->first;
auto dit = double_.find(x);
if(dit != double_.end()) {
ir::value *y = dit->second.latch;
unsigned start = intervals_[x].start;
unsigned end = intervals_[y].end;
intervals_[x] = segment{start, end};
intervals_.erase(y);
}
it++;
indices.insert({instr, index});
}
}
for(auto x: values_) {
// users
std::set<ir::value*> values;
for(ir::value *v: x.second){
values.insert(v);
for(ir::user *u: v->get_users())
values.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
unsigned end = 0;
for(ir::value *u: values){
start = std::min(start, indices.at(u));
end = std::max(end, indices.at(u));
}
intervals_[x.first] = segment{start, end};
}
}
}

View File

@@ -725,7 +725,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
return;
auto order = tiles_->order(v);
auto shapes = v->get_type()->get_tile_shapes();
unsigned pad = alloc_->is_ld_padded(v);
unsigned pad = liveness_->is_ld_padded(v);
if(pad > 0)
shapes[order[0]] += pad;
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), builder.getContext());
@@ -1040,15 +1040,13 @@ void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ct
}
void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
shared_tile* result = (shared_tile*)tmap_.at(x);
distributed_tile* in = (distributed_tile*)tmap_.at(x->get_operand(0));
auto perm = x->get_perm();
in->for_each([&](indices_t idx){
indices_t out_idx(idx.size());
for(size_t i = 0; i < idx.size(); i++)
out_idx[i] = idx[perm[i]->get_value()];
result->set_value(out_idx, in->get_value(idx));
});
shared_tile* in = (shared_tile*)tmap_.at(x->get_operand(0));
auto in_order = in->get_order();
std::vector<int> order;
for(auto p: x->get_perm())
order.push_back(in_order[p->get_value()]);
shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), order, in->get_pointer(), builder, in->get_offset());
tmap_[x] = out;
}
void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder,
@@ -1555,7 +1553,7 @@ void selection::run(ir::module &src, Module &dst) {
}
else {
unsigned num_bytes = phi->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
offset->addIncoming(dst_builder.getInt32(alloc_->num_bytes(phi)/(num_bytes)), llvm_inc_block);
offset->addIncoming(dst_builder.getInt32(liveness_->num_bytes(phi)/(num_bytes)), llvm_inc_block);
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}

View File

@@ -36,9 +36,9 @@ void membar::add_reference(ir::value *v, interval_vec_t &res){
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(storage_info.at(i->get_id()).first == SHARED){
if(alloc_->has_offset(v)){
unsigned offset = alloc_->offset(v);
unsigned num_bytes = alloc_->num_bytes(v);
unsigned num_bytes = liveness_->num_bytes(v);
res.push_back(interval_t(offset, offset + num_bytes));
}
}
@@ -97,8 +97,10 @@ std::pair<membar::interval_vec_t,
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering: write and phi-node read won't intersect
if(safe_war.find(i) != safe_war.end())
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
}
// record hazards
if(read_after_write || write_after_read) {
insert_loc.insert(i);
@@ -122,7 +124,12 @@ void membar::run(ir::module &mod) {
auto info = liveness_->get_double(i);
safe_war.insert(i);
safe_war.insert(info.latch);
auto *trans = dynamic_cast<ir::trans_inst*>(info.latch);
if(trans)
safe_war.insert(trans->get_operand(0));
}
if(i->get_id() == ir::INST_TRANS)
safe_war.insert(i);
});
for(ir::function *fn: mod.get_function_list()){
@@ -152,9 +159,8 @@ void membar::run(ir::module &mod) {
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(ir::instruction* i: insert_locs){
for(ir::instruction* i: insert_locs)
insert_barrier(i, builder);
}
}
}

View File

@@ -84,70 +84,6 @@ bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
return true;
}
// Rewrites `dot` into a new dot instruction created with explicit trans_a /
// trans_b flags, taking the transposition instructions off A and/or B.
//  dot              : instruction whose uses are redirected to the replacement
//  builder          : used to create the new transposition / dot instructions
//  trans_a, trans_b : whether the caller already identified A / B as transposed
//  A, B, D          : the three operands forwarded to the new dot
// Returns false, without creating a dot, when neither operand ends up flagged
// as transposed; returns true once the uses of `dot` have been replaced.
bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
ir::value *A, ir::value *B, ir::value *D){
ir::value *AA = A;
ir::value *BB = B;
if(trans_a){
// NOTE(review): unchecked cast -- assumes A is an ir::trans_inst whenever
// trans_a is set; confirm against the caller.
AA = ((ir::trans_inst*)A)->get_operand(0);
}
else{
// A is a trans_inst the caller did not flag: rebuild it with its first two
// permutation entries swapped, redirect its users to the new value, and
// flag A as transposed.
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
if(trans_b){
// NOTE(review): same unchecked cast as for A above.
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
// Same treatment as for A, applied to B.
if(auto *T = dynamic_cast<ir::trans_inst*>(B)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
BB = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(BB);
trans_b = true;
}
}
// Nothing to fold into the dot: leave it untouched.
if(!trans_a && !trans_b)
return false;
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);
return true;
}
// Replaces `dot` with a dot instruction built by ir::dot_inst::create_nt.
//  dot     : instruction whose uses are redirected to the replacement
//  builder : used to create the transposition (if needed) and the new dot
//  trans_a : not used by this rewrite
//  trans_b : whether the caller identified B as a transposition
//  A, B, D : operands; A and D are forwarded unchanged
// Always returns true: one of the two cases below applies to every input.
bool peephole::rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b,
                                ir::value *A, ir::value *B, ir::value *D){
  ir::value *BB = nullptr;
  if(trans_b){
    // dot(op(a), trans(b)): use the input of the transposition directly.
    // NOTE(review): unchecked downcast -- assumes B is an ir::trans_inst
    // whenever trans_b is set; confirm against the caller.
    BB = static_cast<ir::trans_inst*>(B)->get_operand(0);
  }
  else{
    // dot(op(a), b): build the identity permutation with its first two
    // entries swapped and feed the transposed B to the NT dot instead.
    size_t size = B->get_type()->get_tile_shapes().size();
    std::vector<ir::constant_int*> perm(size);
    ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
    for(size_t i = 0; i < size; i++)
      perm[i] = ir::constant_int::get(int32_ty, i);
    std::swap(perm[0], perm[1]);
    BB = builder.create_trans(B, perm);
  }
  // Shared tail of both cases (previously duplicated, and followed by an
  // unreachable `return false`).
  ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
  dot->replace_all_uses_with(NT);
  return true;
}
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, 0) + c -> dot(a, b, c)
auto add = dynamic_cast<ir::binary_operator*>(value);
@@ -176,26 +112,6 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
add->replace_all_uses_with(new_dot);
return true;
}
// dot(a, b, c)
auto dot = dynamic_cast<ir::dot_inst*>(value);
if(!dot)
return false;
builder.set_insert_point(value);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);
ir::value *D = dot->get_operand(2);
bool trans_a = is_trans(A);
bool trans_b = is_trans(B);
// only consider dot-nn
if(dot->is_a_trans() || dot->is_b_trans())
return false;
// hmma
if(is_hmma(dot)){
return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D);
}
else
return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D);
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){

View File

@@ -202,11 +202,11 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// create passes
codegen::transform::cts cts;
codegen::analysis::align align;
codegen::analysis::liveness liveness;
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::liveness liveness(&tiles);
codegen::analysis::allocation allocation(&liveness, &tiles);
codegen::transform::membar barriers(&liveness, &allocation);
codegen::transform::dce dce;
@@ -235,12 +235,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
return std::unique_ptr<driver::module>();
barriers.run(module);
dce.run(module);
dce.run(module);
axes.run(module);
layouts.run(module);
align.run(module);
tiles.run(module);
// ir::print(module, std::cout);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -79,10 +79,8 @@ int main() {
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false},
{true, true}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, true},
{true, false}, {true, true}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 4096, 4096, 4096}
// config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -4,15 +4,15 @@ namespace src {
R"(
#if AT == 1
#define USEA ^a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USEA a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
@@ -20,15 +20,15 @@ R"(
#if BT == 1
#define USEB ^b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BK newaxis, :
#define BROADCAST_BN :, newaxis
#define SHAPE_B TN, TK
#else
#define USEB b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BK :, newaxis
#define BROADCAST_BN newaxis, :
#define SHAPE_B TK, TN

View File

@@ -139,8 +139,8 @@ int main() {
// shapes to benchmark
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs;
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true})
for(bool AT: std::array<bool, 2>{false})
for(bool BT: std::array<bool, 2>{false})
for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64})
for(int TK: std::vector<int>{16, 32})