[general] various cleanups and bug fixes:
* added copy1d and copy2d benchmarks
* fixed issue in reassociation pass
This commit is contained in:
@@ -13,7 +13,7 @@ namespace ir {
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class alignment_info {
|
||||
class align {
|
||||
struct cst_info {
|
||||
unsigned num_cst;
|
||||
unsigned value;
|
@@ -13,11 +13,10 @@ namespace ir{
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class info;
|
||||
class meminfo;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
@@ -45,7 +44,7 @@ public:
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(info *info): info_(info){ }
|
||||
// Stores the given meminfo pointer in info_; the constructor body is empty.
liveness(meminfo *info): info_(info){ }
|
||||
|
||||
// accessors
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
@@ -55,7 +54,7 @@ public:
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
info *info_;
|
||||
meminfo *info_;
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
indices_map_t indices_;
|
||||
intervals_map_t intervals_;
|
||||
@@ -64,7 +63,6 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -17,14 +17,12 @@ namespace analysis{
|
||||
|
||||
class grids;
|
||||
|
||||
namespace shmem{
|
||||
|
||||
class liveness;
|
||||
class info;
|
||||
class meminfo;
|
||||
|
||||
class allocation {
|
||||
class memalloc {
|
||||
public:
|
||||
allocation(liveness *live, info *buffer_info, grids *params)
|
||||
// Stores the three analysis pointers (liveness, meminfo, grids) in the
// corresponding members; the constructor body is empty.
memalloc(liveness *live, meminfo *buffer_info, grids *params)
|
||||
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
|
||||
|
||||
// utilities
|
||||
@@ -44,13 +42,12 @@ private:
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
info *buffer_info_;
|
||||
meminfo *buffer_info_;
|
||||
grids *params_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -15,9 +15,8 @@ namespace ir {
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
class info {
|
||||
class meminfo {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
// queries
|
||||
@@ -38,6 +37,5 @@ private:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -5,7 +5,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
|
||||
|
||||
namespace llvm{
|
||||
@@ -45,14 +45,10 @@ namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class grids;
|
||||
class alignment_info;
|
||||
class align;
|
||||
class memalloc;
|
||||
class meminfo;
|
||||
|
||||
namespace shmem{
|
||||
|
||||
class allocation;
|
||||
class info;
|
||||
|
||||
}
|
||||
}
|
||||
class target;
|
||||
|
||||
@@ -196,7 +192,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::shmem::allocation *alloc, analysis::grids *params, analysis::shmem::info *buffer_info, analysis::alignment_info *alignment, target *tgt)
|
||||
// Stores the analysis results (memalloc, grids, meminfo, align) and the
// code-generation target in the corresponding members; the body is empty.
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), tgt_(tgt){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
@@ -204,10 +200,10 @@ public:
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
analysis::shmem::allocation *alloc_;
|
||||
analysis::memalloc *alloc_;
|
||||
analysis::grids *params_;
|
||||
analysis::shmem::info *buffer_info_;
|
||||
analysis::alignment_info *alignment_;
|
||||
analysis::meminfo *buffer_info_;
|
||||
analysis::align *alignment_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
Value *sh_mem_ptr_;
|
@@ -46,6 +46,7 @@ public:
|
||||
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual unsigned guaranteed_alignment() = 0;
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
@@ -62,6 +63,7 @@ public:
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
// Alignment guarantee reported by this target: 16.
// NOTE(review): the unit is not visible here — presumably bytes; confirm.
unsigned guaranteed_alignment() { return 16; }
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
@@ -74,6 +76,7 @@ public:
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
// Alignment guarantee reported by this target: 16.
// NOTE(review): the unit is not visible here — presumably bytes; confirm.
unsigned guaranteed_alignment() { return 16; }
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
@@ -86,6 +89,7 @@ public:
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
// Alignment guarantee reported by this target: 1.
// NOTE(review): the unit is not visible here — presumably bytes; confirm.
unsigned guaranteed_alignment() { return 1; }
|
||||
};
|
||||
|
||||
}
|
@@ -14,17 +14,15 @@ namespace ir {
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
class allocation;
|
||||
class info;
|
||||
class memalloc;
|
||||
class meminfo;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class shmem_barriers {
|
||||
class membar {
|
||||
private:
|
||||
typedef std::pair<unsigned, unsigned> interval_t;
|
||||
typedef std::vector<interval_t> interval_vec_t;
|
||||
@@ -40,12 +38,12 @@ private:
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
|
||||
|
||||
public:
|
||||
shmem_barriers(analysis::shmem::allocation *alloc, analysis::shmem::info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
// Stores the memalloc and meminfo analyses that the pass queries later
// (offsets / sizes and shared status); no work is done here.
membar(analysis::memalloc *alloc, analysis::meminfo *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::shmem::allocation *alloc_;
|
||||
analysis::shmem::info *buffer_info_;
|
||||
analysis::memalloc *alloc_;
|
||||
analysis::meminfo *buffer_info_;
|
||||
};
|
||||
|
||||
|
@@ -20,7 +20,7 @@ namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class grids;
|
||||
class alignment_info;
|
||||
class align;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
@@ -37,12 +37,12 @@ private:
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(analysis::alignment_info* align, analysis::grids *params);
|
||||
reassociate(analysis::align* align, analysis::grids *params);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
analysis::grids* params_;
|
||||
analysis::alignment_info* align_;
|
||||
analysis::align* align_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -61,6 +61,19 @@ public:
|
||||
return kind_ != multiple_of;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(kind_){
|
||||
case readonly: return ".readonly";
|
||||
case writeonly: return ".writeonly";
|
||||
case noalias: return ".noalias";
|
||||
case aligned: return ".aligned(" + std::to_string(value_) + ")";
|
||||
case multiple_of: return ".readonly";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
private:
|
||||
attribute_kind_t kind_;
|
||||
unsigned value_;
|
||||
|
@@ -687,7 +687,7 @@ private:
|
||||
public:
|
||||
static nv_static_program_idx *get(constant_range* range);
|
||||
constant_range* get_range() const;
|
||||
std::string repr() const { return get_name(); }
|
||||
// Returns a fixed label as this object's textual representation.
std::string repr() const { return "nv_static_program_idx"; }
|
||||
|
||||
private:
|
||||
constant_range *range_;
|
||||
|
@@ -9,16 +9,16 @@
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
// codegen
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/lang/parser.h"
|
||||
|
@@ -41,8 +41,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
|
||||
tmr.start();
|
||||
op();
|
||||
stream->synchronize();
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -29,14 +29,14 @@ inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
|
||||
}
|
||||
|
||||
|
||||
bool alignment_info::is_first_axis_unit(ir::value *x){
|
||||
// Tells whether the leading tile dimension of `x` has extent 1.
// Values whose type is not a tile type are reported as true.
bool align::is_first_axis_unit(ir::value *x){
  auto ty = x->get_type();
  // non-tile values: nothing to inspect
  if(!ty->is_tile_ty())
    return true;
  return ty->get_tile_shapes()[0] == 1;
}
|
||||
|
||||
alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||
align::cst_info align::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
// helper for the cache
|
||||
@@ -102,7 +102,7 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) {
|
||||
return cache({1, 0});
|
||||
}
|
||||
|
||||
unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
unsigned align::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
// helper for the cache
|
||||
@@ -181,7 +181,7 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){
|
||||
return cache(1);
|
||||
}
|
||||
|
||||
unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
unsigned align::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
auto cache = [this,v](unsigned value){
|
||||
@@ -240,7 +240,19 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
return cache(gcd(lhs, rhs));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
if(shapes[0] == 1)
|
||||
return cache(1);
|
||||
else
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
@@ -271,22 +283,22 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
return cache(result);
|
||||
}
|
||||
|
||||
unsigned alignment_info::get_starting_multiple(ir::value* v) const {
|
||||
// Returns the entry stored for `v` in starting_multiple_.
// The lookup uses .at(v), so `v` must already have an entry
// (NOTE(review): a missing key throws if the container is a std::map — confirm).
unsigned align::get_starting_multiple(ir::value* v) const {
|
||||
return starting_multiple_.at(v);
|
||||
}
|
||||
|
||||
unsigned alignment_info::get_max_contiguous(ir::value* v) const {
|
||||
// Returns the entry stored for `v` in max_contiguous_.
// The lookup uses .at(v), so `v` must already have an entry
// (NOTE(review): a missing key throws if the container is a std::map — confirm).
unsigned align::get_max_contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
void alignment_info::copy(ir::value *dst, ir::value *src) {
|
||||
// Makes `dst` carry the same analysis entries as `src` in the three tables
// (starting multiple, max contiguous, constant info).
// NOTE(review): both sides are indexed with operator[]; if these tables are
// std::map, a `src` without an entry gets a default-constructed one inserted
// and that default is what `dst` receives — confirm callers only pass
// already-analysed values.
void align::copy(ir::value *dst, ir::value *src) {
|
||||
starting_multiple_[dst] = starting_multiple_[src];
|
||||
max_contiguous_[dst] = max_contiguous_[src];
|
||||
is_constant_[dst] = is_constant_[src];
|
||||
}
|
||||
|
||||
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
|
||||
void alignment_info::run(ir::module &mod) {
|
||||
void align::run(ir::module &mod) {
|
||||
// populate constant
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -304,9 +316,13 @@ void alignment_info::run(ir::module &mod) {
|
||||
// populate maximum contiguous
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
populate_max_contiguous(i);
|
||||
}
|
||||
|
||||
// for(ir::function *fn: mod.get_function_list())
|
||||
// for(ir::basic_block *block: fn->blocks())
|
||||
// for(ir::instruction *i: block->get_inst_list())
|
||||
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -292,7 +292,7 @@ void grids::run(ir::module &mod) {
|
||||
else{
|
||||
unsigned shape = shapes[0];
|
||||
unsigned current = num_threads;
|
||||
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8));
|
||||
params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 4));
|
||||
params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value()));
|
||||
current = current / params_.at(i).at("mts.d0")->get_value();
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -9,7 +9,6 @@
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
// Entry point
|
||||
void liveness::run(ir::module &mod) {
|
||||
@@ -41,4 +40,3 @@ void liveness::run(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,8 +1,8 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
@@ -12,9 +12,8 @@
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
unsigned allocation::is_ld_padded(ir::value *x) {
|
||||
unsigned memalloc::is_ld_padded(ir::value *x) {
|
||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
|
||||
if(trans->get_perm()[0]->get_value() != 0)
|
||||
return 4;
|
||||
@@ -46,7 +45,7 @@ unsigned allocation::is_ld_padded(ir::value *x) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned allocation::get_num_bytes(ir::value *x) {
|
||||
unsigned memalloc::get_num_bytes(ir::value *x) {
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
|
||||
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
size_t axis = red->get_axis();
|
||||
@@ -74,7 +73,7 @@ unsigned allocation::get_num_bytes(ir::value *x) {
|
||||
return num_bytes;
|
||||
}
|
||||
|
||||
void allocation::run(){
|
||||
void memalloc::run(){
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
@@ -178,4 +177,3 @@ void allocation::run(){
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,5 +1,5 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -10,10 +10,9 @@ namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
// run pass on module
|
||||
bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
if(phi->get_parent() != terminator->get_parent())
|
||||
return false;
|
||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||
@@ -25,7 +24,7 @@ bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void info::replace(ir::value* before, ir::value *after) {
|
||||
void meminfo::replace(ir::value* before, ir::value *after) {
|
||||
shared_.erase(before);
|
||||
shared_.insert(after);
|
||||
if(refs_.find(before) != refs_.end()){
|
||||
@@ -72,7 +71,7 @@ void add_copy(ir::value *x, ir::builder &builder) {
|
||||
}
|
||||
}
|
||||
|
||||
void info::run(ir::module &mod) {
|
||||
void meminfo::run(ir::module &mod) {
|
||||
// Add shared copies
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
ir::builder builder(mod.get_context());
|
||||
@@ -122,15 +121,15 @@ void info::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// query double-buffered status
|
||||
bool info::is_double(ir::value *x)
|
||||
// Reports whether `x` is recorded in double_ (the double-buffered values).
bool meminfo::is_double(ir::value *x) {
  const auto entry = double_.find(x);
  return entry != double_.end();
}
|
||||
|
||||
// query shared status
|
||||
bool info::is_shared(ir::value *x)
|
||||
// Reports whether `x` is recorded in shared_ (the shared values).
bool meminfo::is_shared(ir::value *x) {
  const auto entry = shared_.find(x);
  return entry != shared_.end();
}
|
||||
|
||||
// get reference if any
|
||||
ir::value *info::get_reference(ir::value *x)
|
||||
// Returns the entry of refs_ associated with `x`.
// NOTE(review): the lookup uses operator[]; if refs_ is a std::map, querying a
// value with no entry inserts a null one, which a later refs_.find() would
// then report as present — confirm this is intended.
ir::value *meminfo::get_reference(ir::value *x)
|
||||
{ return refs_[x]; }
|
||||
|
||||
|
||||
@@ -138,4 +137,3 @@ ir::value *info::get_reference(ir::value *x)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,8 +1,8 @@
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -1304,10 +1304,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
||||
unsigned id = linear / vector_size;
|
||||
if(linear % vector_size == 0) {
|
||||
Value *ptr = pointers->get_value(idx);
|
||||
// ConstantInt *cst = nullptr;
|
||||
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
// if(gep->getNumIndices() == 1)
|
||||
// cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
|
||||
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size),
|
||||
ptr->getType()->getPointerAddressSpace()));
|
||||
@@ -1326,23 +1323,28 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
||||
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
|
||||
Value *result_false = false_values->get_value(idx);
|
||||
if(result_then->getType()->isVectorTy())
|
||||
result_false = builder.CreateVectorSplat(vector_size, result_false);
|
||||
result_false = builder.CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
|
||||
((PHINode*)current_result)->addIncoming(result_false, current_bb);
|
||||
}
|
||||
else
|
||||
current_result = result_then;
|
||||
|
||||
// ConstantInt *cst = nullptr;
|
||||
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
|
||||
// if(gep->getNumIndices() == 1)
|
||||
// cst = dyn_cast<ConstantInt>(gep->idx_begin());
|
||||
// llvm::Value* mask = masks->get_value(idx);
|
||||
// std::string offset = "";
|
||||
// if(cst)
|
||||
// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size);
|
||||
// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
||||
// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false);
|
||||
// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
|
||||
// if(false_value)
|
||||
// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];";
|
||||
// if(false_values)
|
||||
// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};";
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
|
||||
// Value *result = builder.CreateCall(iasm, {mask, ptr});
|
||||
// Value *current_result = builder.CreateCall(iasm, {mask, ptr});
|
||||
|
||||
packets[id] = current_result;
|
||||
}
|
||||
@@ -1499,9 +1501,11 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(auto attr_pair: fn->attrs()){
|
||||
unsigned id = attr_pair.first;
|
||||
for(ir::attribute attr: attr_pair.second)
|
||||
if(attr.is_llvm_attr())
|
||||
if(attr.is_llvm_attr()){
|
||||
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
|
||||
}
|
||||
}
|
||||
|
||||
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
|
||||
// set metadata
|
||||
Metadata *md_args[] = {
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
@@ -2,9 +2,9 @@
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -16,7 +16,7 @@ namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
|
||||
bool membar::intersect(const interval_vec_t &X, interval_t x) {
|
||||
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
|
||||
bool left_intersect = y.first <= x.first && x.first < y.second;
|
||||
bool right_intersect = y.first <= x.second && x.second < y.second;
|
||||
@@ -24,13 +24,13 @@ bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
|
||||
});
|
||||
}
|
||||
|
||||
bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
|
||||
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
|
||||
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
|
||||
return intersect(X, y);
|
||||
});
|
||||
}
|
||||
|
||||
void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
|
||||
void membar::add_reference(ir::value *v, interval_vec_t &res){
|
||||
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::phi_node*>(v)){
|
||||
unsigned offset = alloc_->get_offset(v);
|
||||
unsigned num_bytes = alloc_->get_num_bytes(v);
|
||||
@@ -38,17 +38,17 @@ void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){
|
||||
}
|
||||
}
|
||||
|
||||
void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){
|
||||
// Appends to `res` whatever add_reference() contributes for each operand
// of instruction `i`, in operand order.
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
  for(ir::value *operand: i->ops()){
    add_reference(operand, res);
  }
}
|
||||
|
||||
void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){
|
||||
// Appends to `res` whatever add_reference() contributes for the instruction
// `i` itself; phi nodes are skipped and leave `res` untouched.
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  const bool is_phi = dynamic_cast<ir::phi_node*>(i) != nullptr;
  if(is_phi)
    return;
  add_reference(i, res);
}
|
||||
|
||||
void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
|
||||
std::set<ir::value*> incoming;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
|
||||
@@ -67,16 +67,16 @@ void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder
|
||||
}
|
||||
}
|
||||
|
||||
shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector<interval_vec_t>& intervals) {
|
||||
shmem_barriers::interval_vec_t result;
|
||||
// Concatenates all interval lists in `intervals` into one list, preserving
// the order of the lists and of the intervals inside each list.
// No merging or de-duplication is performed.
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
  interval_vec_t merged;
  for(const interval_vec_t &group: intervals)
    merged.insert(merged.end(), group.begin(), group.end());
  return merged;
}
|
||||
|
||||
std::pair<shmem_barriers::interval_vec_t,
|
||||
shmem_barriers::interval_vec_t> shmem_barriers::transfer(ir::basic_block *block,
|
||||
std::pair<membar::interval_vec_t,
|
||||
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
|
||||
const interval_vec_t &written_to,
|
||||
const interval_vec_t &read_from,
|
||||
std::set<ir::instruction*>& insert_loc) {
|
||||
@@ -104,7 +104,7 @@ std::pair<shmem_barriers::interval_vec_t,
|
||||
return std::make_pair(new_written_to, new_read_from);
|
||||
}
|
||||
|
||||
void shmem_barriers::run(ir::module &mod) {
|
||||
void membar::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
@@ -1,7 +1,8 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -161,7 +162,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(analysis::alignment_info *align, analysis::grids* params)
|
||||
// Stores the alignment analysis and the grids analysis in align_ and params_;
// the constructor body is empty.
reassociate::reassociate(analysis::align *align, analysis::grids* params)
|
||||
: params_(params), align_(align)
|
||||
{ }
|
||||
|
||||
@@ -209,6 +210,29 @@ void reassociate::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// retiling
|
||||
if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)) {
|
||||
ir::value* op = rt->get_operand(0);
|
||||
if(infos.find(op) != infos.end()){
|
||||
builder.set_insert_point(rt);
|
||||
ir::getelementptr_inst* sta = infos.at(op).sta_ptr;
|
||||
ir::value* dyn = infos.at(op).dyn_ptr;
|
||||
ir::value* cst = *sta->idx_begin();
|
||||
if(dynamic_cast<ir::broadcast_inst*>(rt)) {
|
||||
auto shapes = rt->get_type()->get_tile_shapes();
|
||||
ir::value* ndyn = builder.create_broadcast(dyn, shapes);
|
||||
ir::value* broadcast = builder.create_broadcast(cst, shapes);
|
||||
ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast});
|
||||
params_->copy(ndyn, rt);
|
||||
params_->copy(nsta, rt);
|
||||
params_->copy(broadcast, rt);
|
||||
align_->copy(ndyn, rt);
|
||||
align_->copy(nsta, rt);
|
||||
align_->copy(broadcast, rt);
|
||||
infos[rt] = cst_info{ndyn, nsta};
|
||||
}
|
||||
}
|
||||
}
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
if(replaced.find(pz) != replaced.end())
|
||||
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
|
@@ -27,7 +27,7 @@
|
||||
#include <memory>
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
@@ -223,12 +223,12 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"])->setValue(true);
|
||||
// create
|
||||
llvm::SmallVector<char, 0> buffer;
|
||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_70", "", buffer, "", Assembly);
|
||||
module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_60", "", buffer, "", Assembly);
|
||||
std::string result(buffer.begin(), buffer.end());
|
||||
size_t start_replace = result.find(".version");
|
||||
size_t end_replace = result.find('\n', start_replace);
|
||||
assert(start_replace != std::string::npos);
|
||||
result.replace(start_replace, end_replace - start_replace, ".version 6.4");
|
||||
result.replace(start_replace, end_replace - start_replace, ".version 6.0");
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -245,10 +245,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
|
||||
try{
|
||||
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
||||
}catch(exception::cuda::base const &){
|
||||
#ifdef TRITON_LOG_PTX_ERROR
|
||||
//#ifdef TRITON_LOG_PTX_ERROR
|
||||
std::cerr << "Compilation Failed! Log: " << std::endl;
|
||||
std::cerr << errbuf << std::endl;
|
||||
#endif
|
||||
//#endif
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
@@ -29,6 +29,7 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
|
||||
if(it != metadatas_.end()){
|
||||
x->set_metadata(it->second.first, it->second.second);
|
||||
}
|
||||
value->set_name(name);
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::value *value){
|
||||
|
@@ -22,6 +22,18 @@ std::string get_name(ir::value *v, unsigned i) {
|
||||
void print(module &mod, std::ostream& os) {
|
||||
unsigned cnt = 0;
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ;
|
||||
for(ir::argument* arg: fn->args()) {
|
||||
if(arg->get_arg_no() > 0)
|
||||
os << ", ";
|
||||
os << arg->get_type()->repr() << " " << arg->get_name();
|
||||
auto attrs = fn->get_attributes(arg);
|
||||
if(attrs.size() > 0)
|
||||
os << " ";
|
||||
for(ir::attribute attr: attrs)
|
||||
os << attr.repr() << " ";
|
||||
}
|
||||
os << ")" << std::endl;
|
||||
os << "{" << std::endl;
|
||||
for(ir::basic_block *block: fn->blocks()){
|
||||
auto const &predecessors = block->get_predecessors();
|
||||
|
@@ -373,8 +373,11 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
|
||||
for(Object* obj: type->Params()){
|
||||
std::string name = obj->Name();
|
||||
args[i]->set_name(name);
|
||||
for(ASTNode::Attr attr: obj->GetAttrList())
|
||||
fn->add_attr(i, GenIRAttr(attr));
|
||||
if(obj->Type()->ToPointer())
|
||||
fn->add_attr(i + 1, ir::attribute(ir::aligned, 16));
|
||||
for(ASTNode::Attr attr: obj->GetAttrList()){
|
||||
fn->add_attr(i + 1, GenIRAttr(attr));
|
||||
}
|
||||
if(obj->IsRestrictQualified())
|
||||
fn->add_attr(i, ir::attribute(ir::noalias));
|
||||
mod_->set_value(name, nullptr, args[i]);
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#include <regex>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/lang/parser.h"
|
||||
@@ -167,8 +167,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(const std::runtime_error& e) {
|
||||
return;
|
||||
}catch(const driver::exception::cuda::invalid_ptx& e) {
|
||||
return;
|
||||
}
|
||||
// benchmark
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
@@ -191,23 +189,31 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
// create passes
|
||||
codegen::analysis::grids grids(opt.num_warps);
|
||||
codegen::analysis::shmem::info shmem_info;
|
||||
codegen::analysis::shmem::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::analysis::alignment_info alignment_info;
|
||||
codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&grids);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
||||
|
||||
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
grids.run(module);
|
||||
alignment_info.run(module);
|
||||
grids.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
||||
peephole.run(module);
|
||||
|
||||
if(target->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
@@ -217,7 +223,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
|
||||
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
|
@@ -1,4 +1,4 @@
|
||||
foreach(PROG dot)
|
||||
foreach(PROG dot copy1d copy2d)
|
||||
set(TARGET bench_${PROG})
|
||||
add_executable(${TARGET} ${PROG}.cc)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||
|
@@ -17,7 +17,7 @@ inline size_t ceil(size_t x, size_t y) {
|
||||
return (x + y - 1) / y;
|
||||
};
|
||||
|
||||
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
|
||||
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
|
||||
return [M, N](const rt::function::options_t& x) {
|
||||
return rt::grid_t{ceil(M, x.D<int>("TM")),
|
||||
ceil(N, x.D<int>("TN"))};
|
||||
@@ -42,11 +42,9 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
// create options
|
||||
rt::function::options_space_t opt;
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
if(AT)
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"TM", {"64"}});
|
||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"64"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {4};
|
||||
@@ -55,18 +53,18 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
||||
// benchmark available libraries
|
||||
std::vector<double> result;
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
// // cublas
|
||||
// if(cublas::cublasinit()){
|
||||
// NumericT alpha(static_cast<double>(1));
|
||||
// NumericT beta(static_cast<double>(0));
|
||||
// cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
|
||||
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
|
||||
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
|
||||
// result.push_back(tflops(cublas_ms));
|
||||
// }
|
||||
// cublas
|
||||
if(cublas::cublasinit()){
|
||||
NumericT alpha(static_cast<double>(1));
|
||||
NumericT beta(static_cast<double>(0));
|
||||
cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5;
|
||||
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
|
||||
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
|
||||
&alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
|
||||
result.push_back(tflops(cublas_ms));
|
||||
}
|
||||
// triton
|
||||
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream);
|
||||
double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);}, stream);
|
||||
result.push_back(tflops(triton_ms));
|
||||
// done
|
||||
return result;
|
||||
@@ -79,11 +77,9 @@ int main() {
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
||||
{false, true},
|
||||
{true, false}}){
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{x[0], x[1], 8192, 8192, 8192}
|
||||
config_t{x[0], x[1], 2048, 2048, 2048}
|
||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||
// config_t{x[0], x[1], 32, 2048, 2048},
|
||||
// config_t{x[0], x[1], 64, 2048, 2048},
|
||||
|
@@ -24,11 +24,11 @@ typedef enum{
|
||||
typedef enum {
|
||||
CUBLAS_GEMM_DFALT = -1,
|
||||
CUBLAS_GEMM_DEFAULT = -1,
|
||||
CUBLAS_GEMM_ALGO0 = 0,
|
||||
CUBLAS_GEMM_ALGO1 = 1,
|
||||
CUBLAS_GEMM_ALGO2 = 2,
|
||||
CUBLAS_GEMM_ALGO3 = 3,
|
||||
CUBLAS_GEMM_ALGO4 = 4,
|
||||
CUBLAS_GEMM_ALGO0 = 0, // maxwell_sgemm_32x128_nt
|
||||
CUBLAS_GEMM_ALGO1 = 1, // maxwell_sgemm_64x64_nt
|
||||
CUBLAS_GEMM_ALGO2 = 2, // maxwell_sgemm_128x32_nt
|
||||
CUBLAS_GEMM_ALGO3 = 3, // maxwell_sgemm_128x64_nt
|
||||
CUBLAS_GEMM_ALGO4 = 4, // maxwell_sgemm_128x128_nt
|
||||
CUBLAS_GEMM_ALGO5 = 5,
|
||||
CUBLAS_GEMM_ALGO6 = 6,
|
||||
CUBLAS_GEMM_ALGO7 = 7,
|
||||
@@ -102,4 +102,4 @@ typedef enum {
|
||||
CUBLAS_TENSOR_OP_MATH = 1
|
||||
} cublasMath_t;
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@@ -2,33 +2,33 @@ namespace src {
|
||||
|
||||
const char *dot =
|
||||
R"(
|
||||
#ifdef AT
|
||||
#if AT == 1
|
||||
#define USEA ^a
|
||||
#define STRIDE_AK lda
|
||||
#define STRIDE_AM 1
|
||||
#define STRIDE_AK 1
|
||||
#define STRIDE_AM lda
|
||||
#define BROADCAST_AK :, newaxis
|
||||
#define BROADCAST_AM newaxis, :
|
||||
#define SHAPE_A TK, TM
|
||||
#else
|
||||
#define USEA a
|
||||
#define STRIDE_AK 1
|
||||
#define STRIDE_AM lda
|
||||
#define STRIDE_AK lda
|
||||
#define STRIDE_AM 1
|
||||
#define BROADCAST_AK newaxis, :
|
||||
#define BROADCAST_AM :, newaxis
|
||||
#define SHAPE_A TM, TK
|
||||
#endif
|
||||
|
||||
#ifdef BT
|
||||
#if BT == 1
|
||||
#define USEB ^b
|
||||
#define STRIDE_BK 1
|
||||
#define STRIDE_BN ldb
|
||||
#define STRIDE_BK ldb
|
||||
#define STRIDE_BN 1
|
||||
#define BROADCAST_BK newaxis, :
|
||||
#define BROADCAST_BN :, newaxis
|
||||
#define SHAPE_B TN, TK
|
||||
#else
|
||||
#define USEB b
|
||||
#define STRIDE_BK ldb
|
||||
#define STRIDE_BN 1
|
||||
#define STRIDE_BK 1
|
||||
#define STRIDE_BN ldb
|
||||
#define BROADCAST_BK :, newaxis
|
||||
#define BROADCAST_BN newaxis, :
|
||||
#define SHAPE_B TK, TN
|
||||
@@ -58,17 +58,15 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
c += USEA @ USEB;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0;
|
||||
b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0;
|
||||
}
|
||||
// epilogue
|
||||
int rxc[TM] = ridx * TM + 0 ... TM;
|
||||
int ryc[TN] = ridy * TN + 0 ... TN;
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
|
||||
bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
|
||||
*?(checkc) pc = c;
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :] * ldc + rxc[:, newaxis];
|
||||
*pc = c;
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
}
|
||||
|
@@ -3,21 +3,35 @@
|
||||
#ifndef _TRITON_TESTS_UTIL_H
|
||||
#define _TRITON_TESTS_UTIL_H
|
||||
|
||||
#include <iomanip>
|
||||
#include "triton/runtime/function.h"
|
||||
|
||||
namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
// Integer ceiling division: the number of chunks of size `y` needed to cover `x`.
// Computed as the quotient plus one when a remainder exists, so the result stays
// correct for `x` close to SIZE_MAX (the `x + y - 1` form wraps around there).
// `y` must be non-zero; dividing by zero is undefined behavior.
inline size_t ceil(size_t x, size_t y) {
  return x / y + (x % y != 0 ? 1 : 0);
}
|
||||
|
||||
inline rt::function::grid_fn_ty grid(size_t M, size_t N) {
|
||||
// Builds the launch-grid callback for a one-dimensional problem of extent N.
// The returned callable captures N by value and, for a given options object,
// yields a single-entry grid: N divided (rounding up) by whatever
// `D<int>("TN")` reports for those options.
inline rt::function::grid_fn_ty grid1d(size_t N) {
  return [N](const rt::function::options_t& opts) {
    const size_t blocks_n = ceil(N, opts.D<int>("TN"));
    return rt::grid_t{blocks_n};
  };
}
|
||||
|
||||
// Builds the launch-grid callback for a two-dimensional problem of extent M x N.
// The returned callable captures M and N by value and, for a given options
// object, yields a two-entry grid: M rounded up in units of `D<int>("TM")`
// followed by N rounded up in units of `D<int>("TN")`.
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
  return [M, N](const rt::function::options_t& opts) {
    // Evaluated in the same order as the entries appear in the grid.
    const size_t blocks_m = ceil(M, opts.D<int>("TM"));
    const size_t blocks_n = ceil(N, opts.D<int>("TN"));
    return rt::grid_t{blocks_m, blocks_n};
  };
}
|
||||
|
||||
// Two-valued layout tag. The names suggest row-major versus column-major
// storage; the code that consumes it is outside this view.
enum order_t {
  ROWMAJOR = 0,
  COLMAJOR = 1
};
|
||||
|
||||
|
||||
namespace aux{
|
||||
template<std::size_t...> struct seq{};
|
||||
|
||||
@@ -51,11 +65,14 @@ namespace testing {
|
||||
if(hc.size() != rc.size())
|
||||
return false;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2)
|
||||
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@@ -1,5 +1,5 @@
|
||||
foreach(PROG dot)
|
||||
set(TARGET test_${PROG})
|
||||
set(TARGET unit_${PROG})
|
||||
add_executable(${TARGET} ${PROG}.cc)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||
target_link_libraries(${TARGET} triton dl)
|
||||
|
@@ -51,8 +51,8 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
||||
|
||||
|
||||
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
|
||||
typedef half_float::half NumericT;
|
||||
std::string ty = "half";
|
||||
typedef float NumericT;
|
||||
std::string ty = "float";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
drv::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
@@ -78,17 +78,15 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
|
||||
// run
|
||||
rt::function::options_space_t opt;
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
if(AT)
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||
opt.defines.push_back({"TM", {std::to_string(TM)}});
|
||||
opt.defines.push_back({"TN", {std::to_string(TN)}});
|
||||
opt.defines.push_back({"TK", {std::to_string(TK)}});
|
||||
opt.num_warps = {nwarp};
|
||||
rt::function function(src::dot, opt);
|
||||
try {
|
||||
function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);
|
||||
function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);
|
||||
} catch (const std::runtime_error& e) {
|
||||
return true;
|
||||
}
|
||||
|
Reference in New Issue
Block a user