[general][filesystem] added structure and namespace to code generation files

This commit is contained in:
Philippe Tillet
2019-08-07 21:15:54 -07:00
parent 392b55280d
commit 7578c27d3d
35 changed files with 224 additions and 147 deletions

View File

@@ -40,7 +40,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
for(size_t i = 0; i < hc.size(); i++)
hc[i] = static_cast<NumericT>((double)0);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
stream->write(da, true, 0, ha);
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;

View File

@@ -130,7 +130,7 @@ public:
// create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
// blocksparse matmul
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::PARTIAL_TUNING);
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
Tensor *tmp = nullptr;
TensorShape tmp_shapes;

View File

@@ -12,6 +12,7 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class alignment_info {
struct cst_info {
@@ -41,6 +42,7 @@ private:
};
}
}
}

View File

@@ -13,16 +13,18 @@ namespace ir{
}
namespace codegen{
namespace analysis{
class layout;
class target_tuner;
class shmem_liveness;
class shmem_info;
class tune;
class shmem_allocation {
namespace shmem{
class liveness;
class info;
class allocation {
public:
shmem_allocation(shmem_liveness *live, shmem_info *buffer_info, tune *params)
allocation(liveness *live, info *buffer_info, tune *params)
: liveness_(live), buffer_info_(buffer_info), params_(params){ }
// utilities
@@ -41,11 +43,13 @@ private:
std::map<ir::value*, unsigned> num_bytes_;
size_t allocated_size_;
// dependences
shmem_liveness *liveness_;
shmem_info *buffer_info_;
liveness *liveness_;
info *buffer_info_;
tune *params_;
};
}
}
}
}

View File

@@ -14,8 +14,10 @@ namespace ir {
}
namespace codegen{
namespace analysis{
namespace shmem{
class shmem_info {
class info {
public:
void run(ir::module &mod);
// queries
@@ -33,7 +35,8 @@ private:
std::map<ir::value*, ir::value*> refs_;
};
}
}
}
}

View File

@@ -12,10 +12,12 @@ namespace ir{
}
namespace codegen{
namespace analysis{
namespace shmem{
typedef unsigned slot_index;
class shmem_info;
class info;
struct segment {
slot_index start;
@@ -30,7 +32,7 @@ struct segment {
}
};
class shmem_liveness {
class liveness {
private:
typedef std::map<ir::value*, slot_index> indices_map_t;
typedef std::map<ir::value*, segment> intervals_map_t;
@@ -43,7 +45,7 @@ public:
public:
// constructor
shmem_liveness(shmem_info *info): info_(info){ }
liveness(info *info): info_(info){ }
// accessors
const intervals_map_t& intervals() const { return intervals_; }
@@ -53,7 +55,7 @@ public:
void run(ir::module &mod);
private:
shmem_info *info_;
info *info_;
has_storage_map_t has_dedicated_storage_;
indices_map_t indices_;
intervals_map_t intervals_;
@@ -61,5 +63,8 @@ private:
}
}
}
}
#endif

View File

@@ -16,6 +16,7 @@ namespace ir{
}
namespace codegen{
namespace analysis{
class tune {
typedef std::pair<ir::value*, unsigned> node_t;
@@ -67,6 +68,7 @@ private:
};
}
}
}

View File

@@ -7,7 +7,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/analysis/shmem/info.h"
namespace llvm{
@@ -21,12 +21,20 @@ namespace llvm{
namespace triton{
namespace codegen{
class shmem_allocation;
namespace analysis{
class tune;
class shmem_info;
class target;
class alignment_info;
namespace shmem{
class allocation;
class info;
}
}
class target;
typedef std::vector<llvm::Value*> indices_t;
struct distributed_axis {
@@ -138,7 +146,7 @@ private:
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
public:
selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, alignment_info *ax_info, target *tgt)
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
void run(ir::module &src, llvm::Module &dst);
@@ -146,11 +154,11 @@ public:
private:
vmap_t vmap_;
tmap_t tmap_;
shmem_allocation *alloc_;
tune *params_;
analysis::shmem::allocation *alloc_;
analysis::tune *params_;
target *tgt_;
shmem_info *buffer_info_;
alignment_info *axis_info_;
analysis::shmem::info *buffer_info_;
analysis::alignment_info *axis_info_;
std::map<unsigned, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
llvm::Value *offset_a_i_, *offset_a_k_;

View File

@@ -12,7 +12,7 @@ namespace ir {
}
namespace codegen{
class tune;
namespace transform{
class optimize_dce {
public:
@@ -20,7 +20,7 @@ public:
void run(ir::module &mod);
};
}
}
}

View File

@@ -13,18 +13,22 @@ namespace ir {
namespace codegen{
namespace analysis{
class tune;
}
namespace transform{
class optimize_dot {
public:
optimize_dot(tune* params): params_(params) {}
optimize_dot(analysis::tune* params): params_(params) {}
void run(ir::module &mod);
private:
tune* params_;
analysis::tune* params_;
};
}
}
}

View File

@@ -19,8 +19,12 @@ class getelementptr_inst;
namespace codegen{
namespace analysis{
class tune;
class alignment_info;
}
namespace transform{
class reassociate {
struct cst_info {
@@ -34,16 +38,18 @@ private:
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
public:
reassociate(tune *params, alignment_info *align);
reassociate(analysis::tune *params, analysis::alignment_info *align);
void run(ir::module& module);
private:
tune* params_;
alignment_info* align_;
analysis::tune* params_;
analysis::alignment_info* align_;
};
}
}
}
#endif

View File

@@ -17,8 +17,16 @@ namespace ir {
namespace codegen{
class shmem_allocation;
class shmem_info;
namespace analysis{
namespace shmem{
class allocation;
class info;
}
}
namespace transform{
class shmem_barriers {
private:
@@ -36,15 +44,16 @@ private:
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
public:
shmem_barriers(shmem_allocation *alloc, shmem_info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
shmem_barriers(analysis::shmem::allocation *alloc, analysis::shmem::info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
void run(ir::module &mod);
private:
shmem_allocation *alloc_;
shmem_info *buffer_info_;
analysis::shmem::allocation *alloc_;
analysis::shmem::info *buffer_info_;
};
}
}
}

View File

@@ -17,6 +17,7 @@ namespace ir {
}
namespace codegen{
namespace transform{
class optimize_trans {
private:
@@ -28,6 +29,7 @@ public:
};
}
}
}

View File

@@ -9,18 +9,22 @@ namespace ir {
namespace codegen{
class tune;
namespace analysis{
class tune;
}
namespace transform{
class vectorize {
public:
vectorize(tune *params): params_(params){}
vectorize(analysis::tune *params): params_(params){}
void run(ir::module &mod);
private:
tune *params_;
analysis::tune *params_;
};
}
}
}

View File

@@ -8,19 +8,19 @@
#include "triton/ir/print.h"
#include "triton/driver/module.h"
#include "triton/driver/kernel.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/optimize_dce.h"
#include "triton/codegen/optimize_trans.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/alignment_info.h"
#include "triton/codegen/reassociate.h"
#include "triton/codegen/target.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/transform/dot.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/trans.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/runtime/launch_info.h"
#include <functional>
@@ -35,8 +35,10 @@ class translation_unit;
}
namespace codegen{
namespace analysis{
class tune;
}
}
namespace ir {
class module;
@@ -90,18 +92,18 @@ public:
// ir::print(module, std::cout);
}
codegen::tune tune;
codegen::shmem_info shmem_info;
codegen::shmem_liveness shmem_liveness;
codegen::shmem_allocation shmem_allocation;
codegen::shmem_barriers shmem_barriers;
codegen::vectorize vectorize;
codegen::selection selection;
codegen::optimize_dot optimize_dot;
codegen::optimize_dce optimize_dce;
codegen::optimize_trans optimize_trans;
codegen::alignment_info alignment_info;
codegen::reassociate reassociate;
codegen::analysis::tune tune;
codegen::analysis::shmem::info shmem_info;
codegen::analysis::shmem::liveness shmem_liveness;
codegen::analysis::shmem::allocation shmem_allocation;
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize;
codegen::transform::optimize_dot optimize_dot;
codegen::transform::optimize_dce optimize_dce;
codegen::transform::optimize_trans optimize_trans;
codegen::transform::reassociate reassociate;
codegen::target* target_;
};

View File

@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/alignment_info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -7,6 +7,7 @@
namespace triton {
namespace codegen{
namespace analysis{
inline int gcd(int a, int b) {
@@ -310,3 +311,4 @@ void alignment_info::run(ir::module &mod) {
}
}
}

View File

@@ -1,7 +1,7 @@
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/type.h"
#include "triton/ir/value.h"
@@ -10,8 +10,10 @@
namespace triton{
namespace codegen{
namespace analysis{
namespace shmem{
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
unsigned allocation::is_ld_padded(ir::value *x) {
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
if(trans->get_perm()[0]->get_value() != 0)
return 4;
@@ -43,7 +45,7 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
return 0;
}
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
unsigned allocation::get_num_bytes(ir::value *x) {
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis();
@@ -71,7 +73,7 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
return num_bytes;
}
void shmem_allocation::run(){
void allocation::run(){
using std::max;
using std::min;
typedef std::multimap<unsigned, segment> triples_map_type;
@@ -174,3 +176,5 @@ void shmem_allocation::run(){
}
}
}
}

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -8,10 +8,11 @@
namespace triton {
namespace codegen{
namespace analysis{
namespace shmem{
// run pass on module
bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
if(phi->get_parent() != terminator->get_parent())
return false;
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
@@ -23,7 +24,7 @@ bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
throw std::runtime_error("unreachable");
}
void shmem_info::replace(ir::value* before, ir::value *after) {
void info::replace(ir::value* before, ir::value *after) {
shared_.erase(before);
shared_.insert(after);
if(refs_.find(before) != refs_.end()){
@@ -70,7 +71,7 @@ void add_copy(ir::value *x, ir::builder &builder) {
}
}
void shmem_info::run(ir::module &mod) {
void info::run(ir::module &mod) {
// Add shared copies
for(ir::function *fn: mod.get_function_list()){
ir::builder builder(mod.get_context());
@@ -120,18 +121,20 @@ void shmem_info::run(ir::module &mod) {
}
// query double-buffered status
bool shmem_info::is_double(ir::value *x)
bool info::is_double(ir::value *x)
{ return double_.find(x) != double_.end(); }
// query shared status
bool shmem_info::is_shared(ir::value *x)
bool info::is_shared(ir::value *x)
{ return shared_.find(x) != shared_.end(); }
// get reference if any
ir::value *shmem_info::get_reference(ir::value *x)
ir::value *info::get_reference(ir::value *x)
{ return refs_[x]; }
}
}
}
}

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/shmem_liveness.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
@@ -8,10 +8,11 @@
namespace triton{
namespace codegen{
namespace analysis{
namespace shmem{
// Entry point
void shmem_liveness::run(ir::module &mod) {
void liveness::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// Assigns index to each instruction
slot_index index = 0;
@@ -39,3 +40,5 @@ void shmem_liveness::run(ir::module &mod) {
}
}
}
}

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/tune.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -12,6 +12,7 @@
namespace triton{
namespace codegen{
namespace analysis{
tune::tune(): num_global_ranges_(0){ }
@@ -257,7 +258,7 @@ void tune::run(ir::module &mod) {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
if(node.second == 2)
fpw->set_value(1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
@@ -270,7 +271,8 @@ void tune::run(ir::module &mod) {
continue;
auto shapes = i->get_type()->get_tile_shapes();
if(auto *x = dynamic_cast<ir::load_inst*>(i)){
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(fragments_.at({i, 0}) == STRIDED_SCAN){
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
@@ -452,3 +454,4 @@ unsigned tune::get_num_threads() {
}
}
}

View File

@@ -1,15 +1,15 @@
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/target.h"
#include "triton/codegen/alignment_info.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -485,7 +485,7 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
if(params_->get_fragment(v, 0) == tune::STRIDED_SCAN){
if(params_->get_fragment(v, 0) == analysis::tune::STRIDED_SCAN){
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> warp_size(dim);
@@ -1022,7 +1022,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN) {
if(params_->get_fragment(ins, 0) == analysis::tune::STRIDED_SCAN) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
@@ -1047,8 +1047,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
a = builder.CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
// a = ConstantFP::get(builder.getFloatTy(), 1);
// b = ConstantFP::get(builder.getFloatTy(), 1);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
@@ -1060,13 +1058,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
TA->set_return_mode(true);
TB->set_return_mode(true);
std::map<Value*, std::vector<Value*>> fcs;
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
result->for_each([&](indices_t idx){
fcs[{builder.getInt32(0)}].push_back(TC->get_value(idx));
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(TC->get_value(idx));
});
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
@@ -1122,8 +1121,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
// idx_a.push_back(builder.getInt32(0));
// idx_b.push_back(builder.getInt32(0));
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
@@ -1159,9 +1158,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// write back
unsigned i = 0;
result->for_each([&](indices_t idx){
if(i >= fcs.at({builder.getInt32(0)}).size())
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
result->set_value(idx, fcs.at({builder.getInt32(0)})[i++]);
result->set_value(idx, fcs.at(key)[i++]);
});
TA->set_return_mode(false);

View File

@@ -1,4 +1,4 @@
#include "triton/codegen/target.h"
#include "triton/codegen/selection/target.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"

View File

@@ -2,10 +2,11 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/cfg.h"
#include "triton/codegen/optimize_dce.h"
#include "triton/codegen/transform/dce.h"
namespace triton {
namespace codegen{
namespace transform{
void optimize_dce::run(ir::module &mod) {
@@ -60,3 +61,4 @@ void optimize_dce::run(ir::module &mod) {
}
}
}

View File

@@ -1,11 +1,12 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/codegen/optimize_dot.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/transform/dot.h"
#include "triton/codegen/analysis/tune.h"
namespace triton {
namespace codegen{
namespace transform{
inline bool is_trans(ir::value *v){
auto *x = dynamic_cast<ir::trans_inst*>(v);
@@ -109,3 +110,4 @@ void optimize_dot::run(ir::module &mod) {
}
}
}

View File

@@ -1,15 +1,16 @@
#include <algorithm>
#include "triton/codegen/reassociate.h"
#include "triton/codegen/alignment_info.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/cfg.h"
#include "triton/codegen/tune.h"
namespace triton {
namespace codegen{
namespace transform{
//inline Constant *get_gep_cst_offset(GetElementPtrInst *gep){
// std::vector<Value*> idx_vals;
@@ -154,7 +155,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
return new_value;
}
reassociate::reassociate(tune* params, alignment_info* align)
reassociate::reassociate(analysis::tune* params, analysis::alignment_info* align)
: params_(params), align_(align)
{ }
@@ -280,3 +281,4 @@ void reassociate::run(ir::module &mod) {
}
}
}

View File

@@ -1,7 +1,7 @@
#include <algorithm>
#include "triton/codegen/shmem_barriers.h"
#include "triton/codegen/shmem_allocation.h"
#include "triton/codegen/shmem_info.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -11,6 +11,7 @@
namespace triton {
namespace codegen{
namespace transform{
bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
@@ -137,3 +138,4 @@ void shmem_barriers::run(ir::module &mod) {
}
}
}

View File

@@ -1,9 +1,10 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/optimize_trans.h"
#include "triton/codegen/transform/trans.h"
namespace triton {
namespace codegen{
namespace transform{
ir::value* optimize_trans::replace_phi(ir::value* value,
@@ -73,3 +74,4 @@ void optimize_trans::run(ir::module &mod) {
}
}
}

View File

@@ -1,5 +1,5 @@
#include "triton/codegen/vectorize.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -8,6 +8,7 @@
namespace triton {
namespace codegen{
namespace transform{
void vectorize::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
@@ -39,3 +40,4 @@ void vectorize::run(ir::module &mod) {
}
}
}

View File

@@ -62,7 +62,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else{
params_t params = heuristics();
// params_t params = heuristics();
params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1};
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -96,7 +96,7 @@ void dot::triton_c_src(std::ostream &os) const {
restrict read_only align(16) )" + ab_ty_ + R"( *B,
)" + c_ty_ + R"(* C,
int lda, int ldc, int N,
int* lut, int* locks, int nlocks){
int* lut, int* locks, int nlocks) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
float acc[TM, TN] = 0;
@@ -129,10 +129,10 @@ void dot::triton_c_src(std::ostream &os) const {
)" + c_ty_ + R"(" c[TM, TN] = acc;
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
bool checkc[TM, TN] = (rxc < N)[:, newaxis];
if(lockid == 0)
if(lockid == 0) {
@checkc *pc = c;
else
{
}
else {
int *plock = locks + ridx*nlocks + lockid - 1;
int *pcount = plock + get_num_program(0)*nlocks;
while(__atomic_cas(plock, 0, 1));
@@ -147,10 +147,11 @@ void dot::triton_c_src(std::ostream &os) const {
__atomic_exch(plock, 0);
}
})";
os << result;
}
}
}
}

View File

@@ -75,14 +75,14 @@ void dot::triton_c_src(std::ostream &os) const {
std::string ZS = "1";
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = ZS;
std::string XBS0 = "TK", XBS1 = ZS, XBS2 = "TN";
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(xa, 2, 0, 1)" : "xa";
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
@@ -99,15 +99,15 @@ void dot::triton_c_src(std::ostream &os) const {
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, " + ZS;
// std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
// std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int TM = {16, 32, 64, 128};
const tunable int TN = {16, 32, 64, 128};
const tunable int TM = {32};
const tunable int TN = {32};
const tunable int TK = {32};
const tunable int GZ = {1};
@@ -131,8 +131,6 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
@@ -141,13 +139,11 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
bool checkb[)" + BS + R"(] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
xa = __reshape(a, )" + XAS + R"();
xb = __reshape(b, )" + XBS + R"();
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
)" + c_ty_ + R"( c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];

View File

@@ -28,7 +28,7 @@
#include "triton/driver/helpers/CL/infos.hpp"
#include "triton/driver/device.h"
#include "triton/driver/context.h"
#include "triton/codegen/target.h"
#include "triton/codegen/selection/target.h"
namespace triton
{

View File

@@ -1,6 +1,6 @@
#include <string>
#include "triton/lang/lang.h"
#include "triton/codegen/target.h"
#include "triton/codegen/selection/target.h"
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/driver/device.h"