[general][filesystem] added structure and namespace to code generation files
This commit is contained in:
@@ -40,7 +40,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = static_cast<NumericT>((double)0);
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
|
||||
stream->write(da, true, 0, ha);
|
||||
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
|
@@ -130,7 +130,7 @@ public:
|
||||
// create profile
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
|
||||
// blocksparse matmul
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::PARTIAL_TUNING);
|
||||
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
|
||||
Tensor *tmp = nullptr;
|
||||
TensorShape tmp_shapes;
|
||||
|
@@ -12,6 +12,7 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class alignment_info {
|
||||
struct cst_info {
|
||||
@@ -41,6 +42,7 @@ private:
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -13,16 +13,18 @@ namespace ir{
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class layout;
|
||||
class target_tuner;
|
||||
class shmem_liveness;
|
||||
class shmem_info;
|
||||
class tune;
|
||||
|
||||
class shmem_allocation {
|
||||
namespace shmem{
|
||||
|
||||
class liveness;
|
||||
class info;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
shmem_allocation(shmem_liveness *live, shmem_info *buffer_info, tune *params)
|
||||
allocation(liveness *live, info *buffer_info, tune *params)
|
||||
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
|
||||
|
||||
// utilities
|
||||
@@ -41,11 +43,13 @@ private:
|
||||
std::map<ir::value*, unsigned> num_bytes_;
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
shmem_liveness *liveness_;
|
||||
shmem_info *buffer_info_;
|
||||
liveness *liveness_;
|
||||
info *buffer_info_;
|
||||
tune *params_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -14,8 +14,10 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
class shmem_info {
|
||||
class info {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
// queries
|
||||
@@ -33,7 +35,8 @@ private:
|
||||
std::map<ir::value*, ir::value*> refs_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -12,10 +12,12 @@ namespace ir{
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class shmem_info;
|
||||
class info;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
@@ -30,7 +32,7 @@ struct segment {
|
||||
}
|
||||
};
|
||||
|
||||
class shmem_liveness {
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<ir::value*, slot_index> indices_map_t;
|
||||
typedef std::map<ir::value*, segment> intervals_map_t;
|
||||
@@ -43,7 +45,7 @@ public:
|
||||
|
||||
public:
|
||||
// constructor
|
||||
shmem_liveness(shmem_info *info): info_(info){ }
|
||||
liveness(info *info): info_(info){ }
|
||||
|
||||
// accessors
|
||||
const intervals_map_t& intervals() const { return intervals_; }
|
||||
@@ -53,7 +55,7 @@ public:
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
shmem_info *info_;
|
||||
info *info_;
|
||||
has_storage_map_t has_dedicated_storage_;
|
||||
indices_map_t indices_;
|
||||
intervals_map_t intervals_;
|
||||
@@ -61,5 +63,8 @@ private:
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -16,6 +16,7 @@ namespace ir{
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class tune {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
@@ -67,6 +68,7 @@ private:
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -7,7 +7,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
|
||||
|
||||
namespace llvm{
|
||||
@@ -21,12 +21,20 @@ namespace llvm{
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
class shmem_allocation;
|
||||
namespace analysis{
|
||||
|
||||
class tune;
|
||||
class shmem_info;
|
||||
class target;
|
||||
class alignment_info;
|
||||
|
||||
namespace shmem{
|
||||
|
||||
class allocation;
|
||||
class info;
|
||||
|
||||
}
|
||||
}
|
||||
class target;
|
||||
|
||||
typedef std::vector<llvm::Value*> indices_t;
|
||||
|
||||
struct distributed_axis {
|
||||
@@ -138,7 +146,7 @@ private:
|
||||
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
||||
|
||||
public:
|
||||
selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, alignment_info *ax_info, target *tgt)
|
||||
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
|
||||
|
||||
void run(ir::module &src, llvm::Module &dst);
|
||||
@@ -146,11 +154,11 @@ public:
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
tmap_t tmap_;
|
||||
shmem_allocation *alloc_;
|
||||
tune *params_;
|
||||
analysis::shmem::allocation *alloc_;
|
||||
analysis::tune *params_;
|
||||
target *tgt_;
|
||||
shmem_info *buffer_info_;
|
||||
alignment_info *axis_info_;
|
||||
analysis::shmem::info *buffer_info_;
|
||||
analysis::alignment_info *axis_info_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
llvm::Value *sh_mem_ptr_;
|
||||
llvm::Value *offset_a_i_, *offset_a_k_;
|
@@ -12,7 +12,7 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
class tune;
|
||||
namespace transform{
|
||||
|
||||
class optimize_dce {
|
||||
public:
|
||||
@@ -20,7 +20,7 @@ public:
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -13,18 +13,22 @@ namespace ir {
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class tune;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class optimize_dot {
|
||||
public:
|
||||
optimize_dot(tune* params): params_(params) {}
|
||||
optimize_dot(analysis::tune* params): params_(params) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
tune* params_;
|
||||
analysis::tune* params_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -19,8 +19,12 @@ class getelementptr_inst;
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class tune;
|
||||
class alignment_info;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reassociate {
|
||||
struct cst_info {
|
||||
@@ -34,16 +38,18 @@ private:
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(tune *params, alignment_info *align);
|
||||
reassociate(analysis::tune *params, analysis::alignment_info *align);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
tune* params_;
|
||||
alignment_info* align_;
|
||||
analysis::tune* params_;
|
||||
analysis::alignment_info* align_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -17,8 +17,16 @@ namespace ir {
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class shmem_allocation;
|
||||
class shmem_info;
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
class allocation;
|
||||
class info;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class shmem_barriers {
|
||||
private:
|
||||
@@ -36,15 +44,16 @@ private:
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
|
||||
|
||||
public:
|
||||
shmem_barriers(shmem_allocation *alloc, shmem_info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
shmem_barriers(analysis::shmem::allocation *alloc, analysis::shmem::info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
shmem_allocation *alloc_;
|
||||
shmem_info *buffer_info_;
|
||||
analysis::shmem::allocation *alloc_;
|
||||
analysis::shmem::info *buffer_info_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -17,6 +17,7 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class optimize_trans {
|
||||
private:
|
||||
@@ -28,6 +29,7 @@ public:
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -9,18 +9,22 @@ namespace ir {
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class tune;
|
||||
namespace analysis{
|
||||
class tune;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class vectorize {
|
||||
public:
|
||||
vectorize(tune *params): params_(params){}
|
||||
vectorize(analysis::tune *params): params_(params){}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
tune *params_;
|
||||
analysis::tune *params_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -8,19 +8,19 @@
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/optimize_dot.h"
|
||||
#include "triton/codegen/optimize_dce.h"
|
||||
#include "triton/codegen/optimize_trans.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/shmem_barriers.h"
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/codegen/reassociate.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/transform/dot.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/trans.h"
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/runtime/launch_info.h"
|
||||
#include <functional>
|
||||
|
||||
@@ -35,8 +35,10 @@ class translation_unit;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class tune;
|
||||
}
|
||||
}
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
@@ -90,18 +92,18 @@ public:
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
codegen::shmem_info shmem_info;
|
||||
codegen::shmem_liveness shmem_liveness;
|
||||
codegen::shmem_allocation shmem_allocation;
|
||||
codegen::shmem_barriers shmem_barriers;
|
||||
codegen::vectorize vectorize;
|
||||
codegen::selection selection;
|
||||
codegen::optimize_dot optimize_dot;
|
||||
codegen::optimize_dce optimize_dce;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::alignment_info alignment_info;
|
||||
codegen::reassociate reassociate;
|
||||
codegen::analysis::tune tune;
|
||||
codegen::analysis::shmem::info shmem_info;
|
||||
codegen::analysis::shmem::liveness shmem_liveness;
|
||||
codegen::analysis::shmem::allocation shmem_allocation;
|
||||
codegen::analysis::alignment_info alignment_info;
|
||||
codegen::transform::shmem_barriers shmem_barriers;
|
||||
codegen::transform::vectorize vectorize;
|
||||
codegen::transform::optimize_dot optimize_dot;
|
||||
codegen::transform::optimize_dce optimize_dce;
|
||||
codegen::transform::optimize_trans optimize_trans;
|
||||
codegen::transform::reassociate reassociate;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
|
@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to reduce noise in auto-tuning
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
|
||||
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
|
||||
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
|
||||
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
|
||||
tmr.start();
|
||||
op();
|
||||
stream->synchronize();
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
|
||||
inline int gcd(int a, int b) {
|
||||
@@ -310,3 +311,4 @@ void alignment_info::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,7 +1,7 @@
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/value.h"
|
||||
@@ -10,8 +10,10 @@
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
unsigned allocation::is_ld_padded(ir::value *x) {
|
||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
|
||||
if(trans->get_perm()[0]->get_value() != 0)
|
||||
return 4;
|
||||
@@ -43,7 +45,7 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
unsigned allocation::get_num_bytes(ir::value *x) {
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
|
||||
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
size_t axis = red->get_axis();
|
||||
@@ -71,7 +73,7 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
return num_bytes;
|
||||
}
|
||||
|
||||
void shmem_allocation::run(){
|
||||
void allocation::run(){
|
||||
using std::max;
|
||||
using std::min;
|
||||
typedef std::multimap<unsigned, segment> triples_map_type;
|
||||
@@ -174,3 +176,5 @@ void shmem_allocation::run(){
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -8,10 +8,11 @@
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
// run pass on module
|
||||
bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
if(phi->get_parent() != terminator->get_parent())
|
||||
return false;
|
||||
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
|
||||
@@ -23,7 +24,7 @@ bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void shmem_info::replace(ir::value* before, ir::value *after) {
|
||||
void info::replace(ir::value* before, ir::value *after) {
|
||||
shared_.erase(before);
|
||||
shared_.insert(after);
|
||||
if(refs_.find(before) != refs_.end()){
|
||||
@@ -70,7 +71,7 @@ void add_copy(ir::value *x, ir::builder &builder) {
|
||||
}
|
||||
}
|
||||
|
||||
void shmem_info::run(ir::module &mod) {
|
||||
void info::run(ir::module &mod) {
|
||||
// Add shared copies
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
ir::builder builder(mod.get_context());
|
||||
@@ -120,18 +121,20 @@ void shmem_info::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
// query double-buffered status
|
||||
bool shmem_info::is_double(ir::value *x)
|
||||
bool info::is_double(ir::value *x)
|
||||
{ return double_.find(x) != double_.end(); }
|
||||
|
||||
// query shared status
|
||||
bool shmem_info::is_shared(ir::value *x)
|
||||
bool info::is_shared(ir::value *x)
|
||||
{ return shared_.find(x) != shared_.end(); }
|
||||
|
||||
// get reference if any
|
||||
ir::value *shmem_info::get_reference(ir::value *x)
|
||||
ir::value *info::get_reference(ir::value *x)
|
||||
{ return refs_[x]; }
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -8,10 +8,11 @@
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
namespace shmem{
|
||||
|
||||
// Entry point
|
||||
void shmem_liveness::run(ir::module &mod) {
|
||||
void liveness::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
@@ -39,3 +40,5 @@ void shmem_liveness::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tune::tune(): num_global_ranges_(0){ }
|
||||
|
||||
@@ -257,7 +258,7 @@ void tune::run(ir::module &mod) {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
if(node.second == 2)
|
||||
fpw->set_value(1);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
@@ -270,7 +271,8 @@ void tune::run(ir::module &mod) {
|
||||
continue;
|
||||
auto shapes = i->get_type()->get_tile_shapes();
|
||||
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(i)){
|
||||
if(auto *x = dynamic_cast<ir::load_inst*>(i))
|
||||
if(fragments_.at({i, 0}) == STRIDED_SCAN){
|
||||
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
@@ -452,3 +454,4 @@ unsigned tune::get_num_threads() {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,15 +1,15 @@
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopInfo.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
@@ -485,7 +485,7 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
if(params_->get_fragment(v, 0) == tune::STRIDED_SCAN){
|
||||
if(params_->get_fragment(v, 0) == analysis::tune::STRIDED_SCAN){
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> block_size(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
@@ -1022,7 +1022,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN) {
|
||||
if(params_->get_fragment(ins, 0) == analysis::tune::STRIDED_SCAN) {
|
||||
TA->set_vector_size(TC->axis(0).contiguous);
|
||||
TB->set_vector_size(TC->axis(1).contiguous);
|
||||
result->for_each([&](indices_t idx){
|
||||
@@ -1047,8 +1047,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
a = builder.CreateFPCast(a, c_ty);
|
||||
if(b->getType() != c_ty)
|
||||
b = builder.CreateFPCast(b, c_ty);
|
||||
// a = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
// b = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
@@ -1060,13 +1058,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
TA->set_return_mode(true);
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::map<Value*, std::vector<Value*>> fcs;
|
||||
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
fcs[{builder.getInt32(0)}].push_back(TC->get_value(idx));
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
fcs[key].push_back(TC->get_value(idx));
|
||||
});
|
||||
|
||||
|
||||
Type *fp32_ty = builder.getFloatTy();
|
||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||
@@ -1122,8 +1121,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
std::swap(idx_a[0], idx_a[1]);
|
||||
if(!dot->is_b_trans())
|
||||
std::swap(idx_b[0], idx_b[1]);
|
||||
// idx_a.push_back(builder.getInt32(0));
|
||||
// idx_b.push_back(builder.getInt32(0));
|
||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
@@ -1159,9 +1158,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// write back
|
||||
unsigned i = 0;
|
||||
result->for_each([&](indices_t idx){
|
||||
if(i >= fcs.at({builder.getInt32(0)}).size())
|
||||
std::vector<Value*> key(idx.size() - 2);
|
||||
std::copy(idx.begin() + 2, idx.end(), key.begin());
|
||||
if(i >= fcs.at(key).size())
|
||||
i = 0;
|
||||
result->set_value(idx, fcs.at({builder.getInt32(0)})[i++]);
|
||||
result->set_value(idx, fcs.at(key)[i++]);
|
||||
});
|
||||
|
||||
TA->set_return_mode(false);
|
@@ -1,4 +1,4 @@
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Intrinsics.h"
|
@@ -2,10 +2,11 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/codegen/optimize_dce.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
void optimize_dce::run(ir::module &mod) {
|
||||
@@ -60,3 +61,4 @@ void optimize_dce::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,11 +1,12 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/optimize_dot.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/transform/dot.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
inline bool is_trans(ir::value *v){
|
||||
auto *x = dynamic_cast<ir::trans_inst*>(v);
|
||||
@@ -109,3 +110,4 @@ void optimize_dot::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,15 +1,16 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/reassociate.h"
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
//inline Constant *get_gep_cst_offset(GetElementPtrInst *gep){
|
||||
// std::vector<Value*> idx_vals;
|
||||
@@ -154,7 +155,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(tune* params, alignment_info* align)
|
||||
reassociate::reassociate(analysis::tune* params, analysis::alignment_info* align)
|
||||
: params_(params), align_(align)
|
||||
{ }
|
||||
|
||||
@@ -280,3 +281,4 @@ void reassociate::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,7 +1,7 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/shmem_barriers.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -11,6 +11,7 @@
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
|
||||
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
|
||||
@@ -137,3 +138,4 @@ void shmem_barriers::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,9 +1,10 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/optimize_trans.h"
|
||||
#include "triton/codegen/transform/trans.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
@@ -73,3 +74,4 @@ void optimize_trans::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,5 +1,5 @@
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -8,6 +8,7 @@
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void vectorize::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
@@ -39,3 +40,4 @@ void vectorize::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -62,7 +62,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else{
|
||||
params_t params = heuristics();
|
||||
// params_t params = heuristics();
|
||||
params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1};
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -96,7 +96,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
)" + c_ty_ + R"(* C,
|
||||
int lda, int ldc, int N,
|
||||
int* lut, int* locks, int nlocks){
|
||||
int* lut, int* locks, int nlocks) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
float acc[TM, TN] = 0;
|
||||
@@ -129,10 +129,10 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
)" + c_ty_ + R"(" c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
|
||||
bool checkc[TM, TN] = (rxc < N)[:, newaxis];
|
||||
if(lockid == 0)
|
||||
if(lockid == 0) {
|
||||
@checkc *pc = c;
|
||||
else
|
||||
{
|
||||
}
|
||||
else {
|
||||
int *plock = locks + ridx*nlocks + lockid - 1;
|
||||
int *pcount = plock + get_num_program(0)*nlocks;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
@@ -147,10 +147,11 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
__atomic_exch(plock, 0);
|
||||
}
|
||||
})";
|
||||
|
||||
os << result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -75,14 +75,14 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string ZS = "1";
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = ZS;
|
||||
std::string XBS0 = "TK", XBS1 = ZS, XBS2 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
|
||||
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(xa, 2, 0, 1)" : "xa";
|
||||
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
|
||||
std::string usea = AT_ ? "trans(a)" : "a";
|
||||
std::string useb = BT_ ? "trans(b)" : "b";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
@@ -99,15 +99,15 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN, " + ZS;
|
||||
// std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
// std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN";
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int TM = {16, 32, 64, 128};
|
||||
const tunable int TN = {16, 32, 64, 128};
|
||||
const tunable int TM = {32};
|
||||
const tunable int TN = {32};
|
||||
const tunable int TK = {32};
|
||||
const tunable int GZ = {1};
|
||||
|
||||
@@ -131,8 +131,6 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
|
||||
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
|
||||
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
|
||||
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
@@ -141,13 +139,11 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
bool checkb[)" + BS + R"(] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
xa = __reshape(a, )" + XAS + R"();
|
||||
xb = __reshape(b, )" + XBS + R"();
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
|
||||
)" + c_ty_ + R"( c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
|
@@ -28,7 +28,7 @@
|
||||
#include "triton/driver/helpers/CL/infos.hpp"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include <string>
|
||||
#include "triton/lang/lang.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/driver/device.h"
|
||||
|
Reference in New Issue
Block a user