[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-07 18:06:54 -04:00
parent 1783d45bef
commit 650c43ca07
16 changed files with 111 additions and 191 deletions

View File

@@ -5,6 +5,7 @@
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
namespace triton{
@@ -19,10 +20,8 @@ namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
void add_constraint(node_t x, node_t y);
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
@@ -32,21 +31,15 @@ private:
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i);
void update_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
public:
axes();
void run(ir::module &mod);
unsigned get_id(ir::value *value, unsigned ax);
bool has_id(ir::value *value, unsigned ax);
unsigned get_id(ir::value *value, unsigned dim);
private:
// constraints graph
graph_t dependencies_;
std::set<node_t> nodes_;
// parameter groups
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}

View File

@@ -5,6 +5,7 @@
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
namespace triton{
@@ -27,29 +28,24 @@ private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
// list the axes of the given value
std::set<int> axes_of(ir::value *value);
public:
// constructor
layout(analysis::axes *axes);
// run the passes
// accessors
unsigned layout_of(ir::value *value) const;
const std::vector<ir::value*>& values_of(unsigned id) const;
size_t num_layouts() const;
// execution
void run(ir::module &mod);
// get the layout ID of the given value
unsigned id(ir::value *value) const;
// get the values associates with the given ID
const std::vector<ir::value*>& values(unsigned id) const;
// get number of groups
size_t get_num_groups() const;
private:
analysis::axes* axes_;
graph_t dependencies_;
std::set<node_t> nodes_;
std::map<ir::value*, unsigned> groups_;
std::map<unsigned, std::vector<ir::value*>> values_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
};
}

View File

@@ -4,6 +4,7 @@
#include <map>
#include <set>
#include <vector>
#include "triton/tools/graph.h"
namespace triton{
@@ -41,7 +42,7 @@ struct double_buffer_info_t {
};
struct buffer_t {
unsigned id;
size_t id;
size_t size;
bool operator<(buffer_t other) const { return id < other.id; }
};
@@ -63,7 +64,6 @@ public:
private:
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t *buffer);
void extract_double_bufferable(ir::instruction *i);
void extract_buffers(ir::instruction *i);
void get_parents(ir::instruction *i, std::vector<ir::value *>& res);
@@ -98,11 +98,8 @@ private:
intervals_map_t intervals_;
std::map<ir::value*, double_buffer_info_t> double_;
std::map<ir::value*, size_t> pad_;
std::map<ir::value*, std::vector<ir::value*>> parents_;
// graph
std::set<node_t> nodes_;
graph_t graph_;
std::vector<buffer_t*> buffers_;
// buffers
tools::graph<node_t> graph_;
std::map<ir::value*, buffer_t*> groups_;
std::map<buffer_t*, std::vector<ir::value*>> values_;
};

View File

@@ -211,10 +211,10 @@ private:
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc, analysis::tiles *tiles,
analysis::align *alignment, analysis::axes *axes,
analysis::layout *layouts, transform::coalesce* reorder, target *tgt, unsigned num_warps)
analysis::layout *layouts, target *tgt, unsigned num_warps)
: liveness_(liveness), alloc_(alloc), tiles_(tiles),
alignment_(alignment), a_axes_(axes), layouts_(layouts),
reorder_(reorder), tgt_(tgt), num_warps_(num_warps){ }
tgt_(tgt), num_warps_(num_warps){ }
void run(ir::module &src, Module &dst);
@@ -227,7 +227,6 @@ private:
analysis::axes *a_axes_;
analysis::layout *layouts_;
analysis::align *alignment_;
transform::coalesce *reorder_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;

View File

@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
double total_time = 0;
op();
stream->synchronize();
while(total_time*1e-9 < 1e-3){
while(total_time*1e-9 < 1e-2){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))

View File

@@ -270,9 +270,9 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
if(lhs_cst_info[d].num_cst > 0)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst)
if(rhs_cst_info[d].num_cst > 0)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}

View File

@@ -16,22 +16,6 @@ namespace analysis{
axes::axes() {}
void axes::add_constraint(node_t x, node_t y) {
size_t shape_x = 1;
size_t shape_y = 1;
if(x.first->get_type()->is_tile_ty())
shape_x = x.first->get_type()->get_tile_shapes()[x.second];
if(y.first->get_type()->is_tile_ty())
shape_y = y.first->get_type()->get_tile_shapes()[y.second];
if(shape_x == 1 && shape_y == 1)
return;
dependencies_[x].insert(y);
dependencies_[y].insert(x);
nodes_.insert(x);
nodes_.insert(y);
}
void axes::update_graph_reduce(ir::instruction *i) {
auto* red = static_cast<ir::reduce_inst*>(i);
unsigned axis = red->get_axis();
@@ -41,7 +25,7 @@ void axes::update_graph_reduce(ir::instruction *i) {
for(unsigned d = 0; d < in_shapes.size(); d++){
if(d == axis)
continue;
add_constraint({i, current++}, {arg, d});
graph_.add_edge({i, current++}, {arg, d});
}
}
@@ -59,9 +43,9 @@ void axes::update_graph_reshape(ir::instruction *i) {
bool same_shape = res_shapes[d] == op_shapes[current];
// either add edge between axis or just add a node in the graph
if(!is_skewed && same_shape)
add_constraint({i, d}, {op, current++});
graph_.add_edge({i, d}, {op, current++});
else
add_constraint({i, d}, {i, d});
graph_.add_edge({i, d}, {i, d});
// reshaping is skewed
if(res_shapes[d] > 1 && !same_shape)
is_skewed = true;
@@ -74,7 +58,7 @@ void axes::update_graph_trans(ir::instruction *i) {
auto perm = trans->get_perm();
// add edge between axis perm[d] and axis d
for(unsigned d = 0; d < perm.size(); d++)
add_constraint({i, perm[d]}, {op, d});
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_broadcast(ir::instruction *i) {
@@ -86,7 +70,7 @@ void axes::update_graph_broadcast(ir::instruction *i) {
// add edge between non-broadcast axes
for(unsigned d = 0; d < shapes.size(); d ++)
if(op_shapes[d] == shapes[d])
add_constraint({i, d}, {op, d});
graph_.add_edge({i, d}, {op, d});
}
void axes::update_graph_dot(ir::instruction *i) {
@@ -97,11 +81,11 @@ void axes::update_graph_dot(ir::instruction *i) {
ir::value *D = dot->get_operand(2);
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
add_constraint({dot, d}, {D, d});
graph_.add_edge({dot, d}, {D, d});
// add edge for batch dimension
for(unsigned d = 2; d < shapes.size(); d++){
add_constraint({dot, d}, {A, d});
add_constraint({dot, d}, {B, d});
graph_.add_edge({dot, d}, {A, d});
graph_.add_edge({dot, d}, {B, d});
}
}
@@ -116,8 +100,8 @@ void axes::update_graph_elementwise(ir::instruction *i) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
add_constraint({i, d}, {opx, d});
add_constraint({opx, d}, {opy, d});
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
@@ -136,41 +120,19 @@ void axes::update_graph(ir::instruction *i) {
return;
}
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
unsigned axes::get_id(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
unsigned axes::get_id(ir::value *value, unsigned dim) {
return axes_.at({value, dim});
}
bool axes::has_id(ir::value *value, unsigned ax) {
auto it = groups_.find(value);
if(it == groups_.end())
return false;
auto iit = it->second.find(ax);
if(iit == it->second.end())
return false;
return true;
}
void axes::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
// make graph
ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); });
// connected components
unsigned group_id = 0;
while(!nodes_.empty())
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
}
}

View File

@@ -21,36 +21,24 @@ std::set<int> layout::axes_of(ir::value *value) {
// create result
std::set<int> result;
for(size_t d = 0; d < rank; d++)
if(axes_->has_id(value, d))
result.insert(axes_->get_id(value, d));
result.insert(axes_->get_id(value, d));
return result;
}
// connected components
void layout::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x] = group_id;
values_[group_id].push_back(x);
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, group_id);
}
}
// constructor
layout::layout(analysis::axes *axes)
: axes_(axes) { }
// get group id
unsigned layout::id(ir::value *value) const
unsigned layout::layout_of(ir::value *value) const
{ return groups_.at(value); }
// get values
const std::vector<ir::value*>& layout::values(unsigned id) const
const std::vector<ir::value*>& layout::values_of(unsigned id) const
{ return values_.at(id); }
// get number of groups
size_t layout::get_num_groups() const
size_t layout::num_layouts() const
{ return values_.size(); }
// connect two values
@@ -67,12 +55,8 @@ void layout::connect(ir::value *x, ir::value *y) {
std::set_intersection(x_axes.begin(), x_axes.end(),
y_axes.begin(), y_axes.end(),
std::inserter(common, common.begin()));
if(!common.empty()){
nodes_.insert(x);
nodes_.insert(y);
dependencies_[x].insert(y);
dependencies_[y].insert(x);
}
if(!common.empty())
graph_.add_edge(x, y);
}
// make graph
@@ -84,19 +68,16 @@ void layout::make_graph(ir::instruction *i) {
}
}
// run
void layout::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
values_.clear();
// make graph
ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); });
graph_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
unsigned group_id = 0;
while(!nodes_.empty()){
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
}
values_.clear();
groups_.clear();
graph_.connected_components(&values_, &groups_);
}
}

View File

@@ -53,42 +53,20 @@ void liveness::extract_double_bufferable(ir::instruction *i) {
void liveness::make_graph(ir::instruction *i) {
if(has_double(i)){
ir::value *latch = double_[i].latch;
nodes_.insert(i);
nodes_.insert(latch);
graph_[i].insert(latch);
graph_[latch].insert(i);
graph_.add_edge(i, latch);
}
if(i->get_id() == ir::INST_PHI){
ir::phi_node* phi = (ir::phi_node*)i;
for(ir::value* op: phi->ops()){
if(storage_info.at(i->get_id()).first == SHARED){
graph_.add_edge(i, i);
for(ir::value* op: i->ops()){
auto* iop = dynamic_cast<ir::instruction*>(op);
if(!iop || storage_info.at(iop->get_id()).first != SHARED)
continue;
nodes_.insert(phi);
nodes_.insert(op);
graph_[phi].insert(op);
graph_[op].insert(phi);
graph_.add_edge(i, op);
}
}
if(i->get_id() == ir::INST_TRANS){
nodes_.insert(i);
nodes_.insert(i->get_operand(0));
graph_[i].insert(i->get_operand(0));
graph_[i->get_operand(0)].insert(i);
}
}
// connected components
void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, buffer_t* buffer) {
groups_[x] = buffer;
values_[buffer].push_back(x);
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
for(const node_t &y: graph[x])
connected_components(y, nodes, graph, buffer);
}
}
bool is_trans(ir::value *v) {
if(dynamic_cast<ir::trans_inst *>(v)) {
return true;
@@ -121,12 +99,14 @@ bool liveness::do_pad(ir::value *x) {
pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
return a_previous != pad_[a] || b_previous != pad_[b];
}
// padding for trans
if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
ir::value *op = trans->get_operand(0);
size_t previous = pad_[op];
pad_[op] = std::max(pad_[op], pad_[x]);
return previous != pad_[op];
}
// padding for copy to shared
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
auto cts_order = tiles_->order(cts);
ir::value *arg = cts->get_operand(0);
@@ -187,7 +167,7 @@ void liveness::run(ir::module &mod) {
indices.clear();
pad_.clear();
intervals_.clear();
parents_.clear();
graph_.clear();
// Create set of pair of values that can be double-buffered
ir::for_each_instruction(mod, [this](ir::instruction* i) {
@@ -209,12 +189,16 @@ void liveness::run(ir::module &mod) {
});
// connected components
unsigned group_id = 0;
while(!nodes_.empty()){
buffer_t* buffer = new buffer_t{group_id++};
connected_components(*nodes_.begin(), nodes_, graph_, buffer);
for(ir::value *v: values_.at(buffer))
tools::graph<node_t>::cmap_t cmap;
tools::graph<node_t>::nmap_t nmap;
graph_.connected_components(&cmap, &nmap);
for(auto x: cmap) {
buffer_t* buffer = new buffer_t{x.first};
values_[buffer] = x.second;
for(ir::value *v: x.second){
buffer->size = std::max<int>(buffer->size, num_bytes(v));
groups_[v] = buffer;
}
}
// Assigns index to each instruction
@@ -245,6 +229,8 @@ void liveness::run(ir::module &mod) {
intervals_[x.first] = segment{start, end};
}
}
}

View File

@@ -74,7 +74,7 @@ bool is_hmma_b_row(ir::value* v) {
layout_t tiles::hmma(ir::value *value) {
return hmma_.at(layout_->id(value));
return hmma_.at(layout_->layout_of(value));
}
int tiles::mts(ir::value *value, unsigned ax) {
@@ -94,7 +94,7 @@ int tiles::wpt(ir::value *value, unsigned ax) {
}
std::vector<int> tiles::order(ir::value *v) {
auto ret = order_[layout_->id(v)];
auto ret = order_[layout_->layout_of(v)];
return ret;
}
@@ -201,7 +201,9 @@ bool tiles::is_trans(ir::value *v) {
void tiles::run(ir::module &) {
hmma_.clear();
largest_.clear();
size_t num_groups = layout_->get_num_groups();
order_.clear();
size_t num_groups = layout_->num_layouts();
// helpers
auto rank = [](ir::value* v) {
ir::type *ty = v->get_type();
@@ -213,7 +215,7 @@ void tiles::run(ir::module &) {
};
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
const auto& values = layout_->values_of(i);
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
if(hmma_c) hmma_[i] = HMMA_C;
else hmma_[i] = SCANLINE;
@@ -221,7 +223,7 @@ void tiles::run(ir::module &) {
}
// find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
const auto& values = layout_->values_of(i);
auto cmp = [&rank](ir::value* x, ir::value *y) { return rank(x) < rank(y); };
largest_[i] = *std::max_element(values.begin(), values.end(), cmp);
}
@@ -230,7 +232,7 @@ void tiles::run(ir::module &) {
// find out the layout ordering of a group
for(size_t i = 0; i < num_groups; i++){
std::set<ir::io_inst*> io;
for(ir::value* v: layout_->values(i))
for(ir::value* v: layout_->values_of(i))
extract_io_use(v, io);
auto cmp = [&rank](ir::io_inst* x, ir::io_inst *y) {
return rank(x->get_pointer_operand()) < rank(y->get_pointer_operand());
@@ -249,27 +251,27 @@ void tiles::run(ir::module &) {
// matrix multiplication optimizations
for(size_t i = 0; i < num_groups; i++){
std::vector<ir::dot_inst*> dots;
for(ir::value* v: layout_->values(i))
for(ir::value* v: layout_->values_of(i))
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
dots.push_back(x);
for(ir::dot_inst* dot: dots){
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
if(hmma_.at(layout_->id(dot)) == HMMA_C){
auto a_val = layout_->values(layout_->id(a));
auto b_val = layout_->values(layout_->id(b));
if(hmma_.at(layout_->layout_of(dot)) == HMMA_C){
auto a_val = layout_->values_of(layout_->layout_of(a));
auto b_val = layout_->values_of(layout_->layout_of(b));
for(ir::value *v: a_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->id(a)] = order_[layout_->id(cts->get_operand(0))];
order_[layout_->layout_of(a)] = order_[layout_->layout_of(cts->get_operand(0))];
for(ir::value *v: b_val)
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
order_[layout_->id(b)] = order_[layout_->id(cts->get_operand(0))];
order_[layout_->layout_of(b)] = order_[layout_->layout_of(cts->get_operand(0))];
}
else{
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order_[layout_->id(a)] = is_trans(a) ? row : col;
order_[layout_->id(b)] = is_trans(b) ? col : row;
order_[layout_->layout_of(a)] = is_trans(a) ? row : col;
order_[layout_->layout_of(b)] = is_trans(b) ? col : row;
}
}
}

View File

@@ -67,10 +67,10 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
void coalesce::run(ir::module &mod) {
// find values to rematerialize
size_t num_groups = layout_->get_num_groups();
size_t num_groups = layout_->num_layouts();
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values(id);
const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
for(ir::value *v: values)

View File

@@ -269,7 +269,10 @@ void reassociate::run(ir::module &mod) {
it++;
builder.set_insert_point(*it);
}
ir::value *neg_off = builder.create_neg(off);
ir::value *_0 = builder.get_int32(0);
if(off->get_type()->is_tile_ty())
_0 = builder.create_splat(_0, off->get_type()->get_tile_shapes());
ir::value *neg_off = builder.create_sub(_0, off);
ir::value *pz_dyn = builder.create_gep(pz, {neg_off});
phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
infos[phi_sta].dyn_ptr = phi_dyn;

View File

@@ -200,11 +200,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes
codegen::transform::cts cts;
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::liveness liveness(&tiles);
codegen::analysis::allocation allocation(&liveness, &tiles);
@@ -212,11 +210,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&align);
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts;
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, target.get(), opt.num_warps);
// run passes
peephole.run(module);
dce.run(module);
// ir::print(module, std::cout);
align.run(module);
cts.run(module);
axes.run(module);
@@ -225,11 +224,15 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
align.run(module);
dce.run(module);
tiles.run(module);
reassociate.run(module);
// ir::print(module, std::cout);
// exit(EXIT_FAILURE);
dce.run(module);
cts.run(module);
// ir::print(module, std::cout);
align.run(module);
axes.run(module);
layouts.run(module);
tiles.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
@@ -238,10 +241,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
axes.run(module);
layouts.run(module);
align.run(module);
// ir::print(module, std::cout);
tiles.run(module);
// ir::print(module, std::cout);
align.run(module);
tiles.run(module);
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -34,7 +34,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

View File

@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.num_warps = {nwarp};
}
if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"16"}});
opt.num_warps = {4};
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {2, 4, 8};
}
// kernels

View File

@@ -13,7 +13,7 @@ int main() {
for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64})
for(int TK: std::vector<int>{8})
for(int nwarps: std::vector<int>{1, 2, 4, 8})
for(int nwarps: std::vector<int>{1, 4})
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
@@ -29,7 +29,6 @@ int main() {
std::cout << " Pass! " << std::endl;
else{
std::cout << " Fail! " << std::endl;
exit(EXIT_FAILURE);
}
}
}