[GENERAL] Merged v1.0alpha into master. Added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -27,7 +27,7 @@ private:
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);

View File

@@ -25,7 +25,7 @@ class axes;
class align;
class layout_visitor;
class data_layout;
class mma884_layout;
class mma_layout;
class scanline_layout;
class shared_layout;
@@ -33,7 +33,7 @@ class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_hmma_884(mma884_layout*) = 0;
virtual void visit_layout_mma(mma_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
@@ -41,7 +41,7 @@ public:
class data_layout {
protected:
enum id_t {
HMMA_884,
MMA,
SCANLINE,
SHARED
};
@@ -68,7 +68,7 @@ public:
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
@@ -77,9 +77,10 @@ public:
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
std::vector<int> get_axes() const { return axes_; }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
size_t find_axis(int to_find) const;
int find_axis(int to_find) const;
private:
@@ -92,21 +93,29 @@ protected:
shape_t shape_;
};
class mma884_layout: public data_layout {
class mma_layout: public data_layout {
public:
mma884_layout(size_t num_warps,
mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int spt(size_t k) { return spt_.at(k); }
int rep(size_t k) { return rep_.at(k); }
private:
std::vector<int> fpw_;
std::vector<int> spw_;
std::vector<int> wpt_;
std::vector<int> spt_;
std::vector<int> rep_;
};
struct scanline_layout: public data_layout {
@@ -138,7 +147,7 @@ private:
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
public:
shared_layout(const data_layout *arg,
shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
@@ -149,11 +158,22 @@ public:
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
data_layout* get_arg_layout() { return arg_layout_; }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
};

View File

@@ -0,0 +1,43 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#include <map>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class target;
namespace analysis{
class layouts;
class data_layout;
class swizzle {
public:
// constructor
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
// accessors
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
int get_vec (data_layout* layout) { return vec_.at(layout); }
// run
void run(ir::module &mod);
private:
layouts* layouts_;
target* tgt_;
std::map<data_layout*, int> per_phase_;
std::map<data_layout*, int> max_phase_;
std::map<data_layout*, int> vec_;
};
}
}
}
#endif

View File

@@ -5,13 +5,14 @@
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/selection/machine_value.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class BasicBlock;
class Attribute;
class Instruction;
class Constant;
class LLVMContext;
@@ -25,6 +26,13 @@ namespace llvm{
}
namespace triton{
namespace ir{
class attribute;
class load_inst;
class store_inst;
}
namespace codegen{
// forward
@@ -36,6 +44,7 @@ class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -43,17 +52,14 @@ typedef llvm::IRBuilder<llvm::ConstantFolder,
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
// forward
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
@@ -62,110 +68,129 @@ class target;
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void init_idx(ir::value *x);
Instruction* add_barrier();
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_load_inst(ir::load_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_program_inst(ir::get_num_program_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
void visit_layout_hmma_884(analysis::mma884_layout*);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
void visit(ir::module &, llvm::Module &);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *sh_mem_ptr_;
Value *shmem_;
unsigned num_warps_;
std::set<ir::value*> seen_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
std::map<analysis::data_layout*, Value*> shared_off_;
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
};
}

View File

@@ -1,138 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
#define _TRITON_SELECTION_MACHINE_LAYOUT_H_
#include <map>
#include "triton/codegen/analysis/layout.h"
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class value;
}
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class machine_data_layout {
public:
virtual tile* create(ir::value *v) = 0;
};
class machine_shared_layout: public machine_data_layout {
public:
machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::shared_layout* layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::allocation* alloc_;
Value *&sh_mem_ptr_;
analysis::shared_layout* layout_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
Value *offset_;
Value *ptr_;
Value *pre_ptr_;
Value *next_ptr_;
};
class machine_distributed_layout: public machine_data_layout {
public:
machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout* layout);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::data_layout* layout_;
};
class machine_mma884_layout: public machine_distributed_layout {
public:
machine_mma884_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
unsigned pack_size_1_;
unsigned num_packs_0_;
unsigned num_packs_1_;
};
class machine_scanline_layout: public machine_distributed_layout {
public:
machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::scanline_layout* layout);
};
}
}
#endif

View File

@@ -1,152 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
#define _TRITON_SELECTION_MACHINE_VALUE_H_
#include <vector>
#include <map>
#include <functional>
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
typedef std::vector<Value*> indices_t;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class tile {
protected:
typedef std::vector<unsigned> shapes_t;
public:
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, Value *v) = 0;
virtual Value* get_value(indices_t idx) = 0;
Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
Type *ty_;
shapes_t shapes_;
};
class shared_tile: public tile {
private:
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
public:
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, Value *);
Value* get_ptr_to(indices_t idx);
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
const std::vector<int>& get_perm() { return perm_; }
const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
private:
Value *ptr_;
bool return_vector_;
Builder &builder_;
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
std::vector<int> order_;
std::vector<int> perm_;
};
// Distribtued tile
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::map<indices_t, Value*> values_map_t;
private:
void init_indices();
public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
private:
axes_t axes_;
std::vector<int> order_;
indices_map_t indices_;
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
Builder &builder_;
};
}
}
#endif

View File

@@ -35,6 +35,8 @@ namespace codegen{
namespace triton{
namespace codegen{
class nvidia_cu_target;
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
@@ -47,6 +49,7 @@ public:
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
nvidia_cu_target* as_nvidia();
bool is_gpu() const;
private:
@@ -68,7 +71,7 @@ public:
class nvidia_cu_target: public target {
public:
nvidia_cu_target(): target(true){}
nvidia_cu_target(int sm): target(true), sm_(sm){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
@@ -76,7 +79,11 @@ public:
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
int sm() { return sm_; }
unsigned guaranteed_alignment() { return 16; }
private:
int sm_;
};
class cpu_target: public target {

View File

@@ -11,14 +11,22 @@ namespace ir {
class value;
class phi_node;
class instruction;
class builder;
}
namespace codegen{
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
public:
cts(bool use_async = false): use_async_(use_async) {}
void run(ir::module &mod);
private:
bool use_async_;
};
}

View File

@@ -1,6 +1,8 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
namespace triton {
namespace ir {
@@ -31,14 +33,14 @@ private:
private:
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
void insert_barrier(ir::instruction *instr, ir::builder &builder);
void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
bool intersect(const interval_vec_t &X, interval_t x);
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -1,6 +1,7 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include "triton/codegen/target.h"
namespace triton {
@@ -27,12 +28,16 @@ private:
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
private:
public:
peephole() {}
peephole(target* tgt): tgt_(tgt) {}
void run(ir::module &mod);
private:
target* tgt_;
};

View File

@@ -0,0 +1,26 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
namespace triton {
// forward declaration
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class reorder {
public:
void run(ir::module& module);
};
}
}
}
#endif

View File

@@ -39,43 +39,23 @@ public:
// CUDA device
class cu_device: public device {
public:
//Supported architectures
enum class Architecture{
//NVidia
SM_2_0,
SM_2_1,
SM_3_0,
SM_3_5,
SM_3_7,
SM_5_0,
SM_5_2,
SM_6_0,
SM_6_1,
SM_7_0,
UNKNOWN
};
private:
//Metaprogramming elper to get cuda info from attribute
template<CUdevice_attribute attr>
int cuGetInfo() const;
inline Architecture nv_arch(std::pair<unsigned int, unsigned int> sm) const;
inline nvmlDevice_t nvml_device() const;
public:
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
// Accessors
Architecture architecture() const;
// Informations
std::string infos() const;
size_t address_bits() const;
std::vector<size_t> max_block_dim() const;
size_t warp_size() const;
// Compute Capability
void interpret_as(std::pair<size_t, size_t> cc);
std::pair<size_t, size_t> compute_capability() const;
void interpret_as(int cc);
int compute_capability() const;
// Identifier
std::string name() const;
std::string pci_bus_id() const;
@@ -91,7 +71,7 @@ public:
std::unique_ptr<codegen::target> make_target() const;
private:
std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
std::shared_ptr<int> interpreted_as_;
};
}

View File

@@ -19,18 +19,18 @@ namespace triton
namespace nvrtc
{
#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
#undef ISAAC_CREATE_NVRTC_EXCEPTION
#undef TRITON_CREATE_NVRTC_EXCEPTION
}
@@ -38,107 +38,107 @@ namespace triton
{
class base: public std::exception{};
#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found");
ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
#undef ISAAC_CREATE_CUDA_EXCEPTION
#undef TRITON_CREATE_CUDA_EXCEPTION
}
namespace cublas
{
class base: public std::exception{};
#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
#undef ISAAC_CREATE_CUBLAS_EXCEPTION
#undef TRITON_CREATE_CUBLAS_EXCEPTION
}
namespace cudnn
{
#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
}

View File

@@ -44,6 +44,13 @@ public:
const std::string &features,
file_type_t file_type);
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
std::string llir() const { return llir_; }
int spilled() const { return spilled_; }
private:
std::string llir_;
protected:
int spilled_;
};
// CPU
@@ -59,12 +66,12 @@ class cu_module: public module {
public:
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
cu_module(const std::string& source);
cu_module(driver::device* device, const std::string& source);
std::unique_ptr<buffer> symbol(const char * name) const;
const std::string& source() const { return source_; }
const std::string& ptx() const { return ptx_; }
private:
std::string source_;
std::string ptx_;
};

View File

@@ -146,8 +146,10 @@ public:
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
value *create_async_wait();
private:
context &ctx_;

View File

@@ -7,7 +7,7 @@ namespace triton{
namespace ir{
enum binary_op_t {
enum binary_op_t: unsigned int{
Add,
FAdd,
Sub,
@@ -28,7 +28,7 @@ enum binary_op_t {
Xor
};
enum cast_op_t {
enum cast_op_t: unsigned int {
Trunc,
ZExt,
SExt,
@@ -44,7 +44,7 @@ enum cast_op_t {
AddrSpaceCast
};
enum cmp_pred_t {
enum cmp_pred_t: unsigned int {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
@@ -113,6 +113,7 @@ enum value_id_t: unsigned {
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// retile
@@ -139,6 +140,7 @@ enum value_id_t: unsigned {
INST_COPY_FROM_SHARED,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE

View File

@@ -72,6 +72,7 @@ public:
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".readonly";
case retune: return ".retunr";
default: break;
}
assert(false);

View File

@@ -64,9 +64,10 @@ public:
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
for(auto it = op_begin(); it != op_end(); it++)
(*it)->add_use(res);
// for(auto it = op_begin(); it != op_end(); it++)
// (*it)->add_use(res);
res->parent_ = nullptr;
res->users_.clear();
return res;
}
// instruction id
@@ -431,6 +432,25 @@ public:
_TRITON_DEFINE_ACCEPT(masked_load_inst)
};
// masked load async
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async_async"; }
masked_load_async_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
};
class atomic_add_inst: public io_inst {
private:
atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
@@ -757,6 +777,7 @@ public:
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class recoalesce_inst: public unary_inst{
private:
using unary_inst::unary_inst;
@@ -780,6 +801,18 @@ public:
instruction *next = nullptr);
};
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait"; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant

View File

@@ -65,7 +65,9 @@ class select_inst;
class recoalesce_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class masked_load_async_inst;
class barrier_inst;
class async_wait_inst;
class make_range_dyn;
class make_range;
@@ -139,7 +141,9 @@ public:
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;

View File

@@ -0,0 +1,34 @@
#pragma once
#ifndef _TRITON_RUNTIME_ERROR_H_
#define _TRITON_RUNTIME_ERROR_H_
#include <exception>
#include <string>
namespace triton {
namespace runtime{
namespace exception {
class base: public std::exception {};
#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")
class no_valid_configuration: public exception::base {
public:
no_valid_configuration(const std::string& err): err_(err) { }
const char * what() const throw(){ return err_.c_str(); }
private:
std::string err_;
};
#undef TRITON_CREATE_RUNTIME_EXCEPTION
}
}
}
#endif

View File

@@ -6,6 +6,7 @@
#include <map>
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <functional>
#include <set>
@@ -13,6 +14,7 @@
#include "triton/ir/context.h"
#include "triton/codegen/target.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/error.h"
namespace llvm {
class Module;
@@ -56,33 +58,43 @@ template<typename T> inline T convert(const std::string& name);
template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
template<class T>
void add_arg(std::stringstream& ss, T arg) {
ss.write((char*)&arg, sizeof(T));
}
enum asm_mode_t {
ASM_LLIR,
ASM_NV_PTX,
ASM_NV_SASS
};
struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
std::vector<int> recompile_key;
};
struct options_t {
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
bool operator<(const options_t& other) const {
return std::make_pair(defines, num_warps) <
std::make_pair(other.defines, other.num_warps);
}
std::string to_str() const;
std::map<std::string, std::string> defines;
size_t num_warps;
};
class function {
public:
struct options_space_t {
typedef std::pair<std::string, std::vector<std::string>> define_t;
std::vector<define_t> defines;
std::vector<int> num_warps;
std::vector<int> recompile_key;
};
struct options_t {
template<class T>
T D(const std::string& name) const {
return convert<T>(defines.at(name));
}
bool operator<(const options_t& other) const {
return std::make_pair(defines, num_warps) <
std::make_pair(other.defines, other.num_warps);
}
std::string to_str() const;
std::map<std::string, std::string> defines;
size_t num_warps;
};
typedef std::function<grid_t(const options_t&)> grid_fn_ty;
private:
class caller {
public:
@@ -135,7 +147,7 @@ public:
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
void set_cst(const char* name, void* data, size_t n_bytes);
std::string ptx(driver::device *device, const options_t& opt);
std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt);
private:
std::map<std::string, std::vector<char>> cst_;

View File

@@ -33,25 +33,20 @@ private:
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
{
// const driver::device * device = stream->context()->device();
size_t warmup = 10;
size_t repeat = 50;
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
for(size_t i = 0; i < warmup; i++)
op();
stream->synchronize();
tmr.start();
for(size_t i = 0; i < 10; i++){
// while(total_time*1e-9 < 1e-2){
// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
for(size_t i = 0; i < repeat; i++){
op();
// times.push_back(norm*tmr.get().count());
// total_time+=times.back();
}
stream->synchronize();
return (float)tmr.get().count() / 10;
return (float)tmr.get().count() / repeat;
// return *std::min_element(times.begin(), times.end());
}