[GENERAL] Merged v1.0alpha into master. Added features:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding
- Complete overhaul of the LLVM code generation in
  codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python bindings (see the sketch after
  the commit metadata below)
- Compilation errors for kernels that spill registers
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions
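The Python-facing debugging hooks wrap a new C++ entry point; here is a minimal caller-side sketch (not code from this commit; it assumes the function::get_asm method and asm_mode_t enum declared in triton/runtime/function.h further below):

#include <iostream>
#include "triton/runtime/function.h"

// Hypothetical helper: dump the PTX of a compiled kernel for inspection.
// ASM_LLIR and ASM_NV_SASS are the other modes added by this commit.
void dump_ptx(triton::runtime::function &fn,
              triton::driver::device *device,
              const triton::runtime::options_t &opt) {
  std::cout << fn.get_asm(triton::runtime::ASM_NV_PTX, device, opt) << std::endl;
}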

View File

@@ -28,7 +28,8 @@
 # We also want an user-specified LLVM_ROOT_DIR to take precedence over the
 # system default locations such as /usr/local/bin. Executing find_program()
 # multiples times is the approach recommended in the docs.
-set(llvm_config_names llvm-config-10 llvm-config-10.0 llvm-config100
+set(llvm_config_names llvm-config-11 llvm-config-11.0
+                      llvm-config-10 llvm-config-10.0 llvm-config100
                       llvm-config-9 llvm-config-9.0 llvm-config90
                       llvm-config-8 llvm-config-8.0 llvm-config80
                       llvm-config)

View File

@@ -27,7 +27,7 @@ private:
   void update_graph_trans(ir::instruction *i);
   void update_graph_broadcast(ir::instruction *i);
   void update_graph_dot(ir::instruction *i);
-  void update_graph_elementwise(ir::instruction *i);
+  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
   void update_graph_no_edge(ir::instruction *i);
   void update_graph(ir::instruction *i);

View File

@@ -25,7 +25,7 @@ class axes;
 class align;
 class layout_visitor;
 class data_layout;
-class mma884_layout;
+class mma_layout;
 class scanline_layout;
 class shared_layout;
@@ -33,7 +33,7 @@ class shared_layout;
 class layout_visitor {
 public:
   virtual void visit_layout(data_layout *);
-  virtual void visit_layout_hmma_884(mma884_layout*) = 0;
+  virtual void visit_layout_mma(mma_layout*) = 0;
   virtual void visit_layout_scanline(scanline_layout*) = 0;
   virtual void visit_layout_shared(shared_layout*) = 0;
 };
@@ -41,7 +41,7 @@ public:
 class data_layout {
 protected:
   enum id_t {
-    HMMA_884,
+    MMA,
     SCANLINE,
     SHARED
   };
@@ -68,7 +68,7 @@ public:
   // visitor
   virtual void accept(layout_visitor* vst) = 0;
   // downcast
-  mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
+  mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
   scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
   shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
   // accessors
@@ -77,9 +77,10 @@ public:
   const order_t& get_order() const { return order_; }
   const values_t& get_values() const { return values_;}
   int get_axis(size_t k) const { return axes_.at(k); }
+  std::vector<int> get_axes() const { return axes_; }
   const int get_order(size_t k) const { return order_.at(k); }
   // find the position of given axis
-  size_t find_axis(int to_find) const;
+  int find_axis(int to_find) const;
 private:
@@ -92,21 +93,29 @@ protected:
   shape_t shape_;
 };
 
-class mma884_layout: public data_layout {
+class mma_layout: public data_layout {
 public:
-  mma884_layout(size_t num_warps,
-                const std::vector<int>& axes,
-                const std::vector<unsigned>& shapes,
-                const std::vector<ir::value *> &values,
-                analysis::align* align);
-  void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
+  mma_layout(size_t num_warps,
+             const std::vector<int>& axes,
+             const std::vector<unsigned>& shapes,
+             const std::vector<ir::value *> &values,
+             analysis::align* align, target *tgt,
+             shared_layout* layout_a,
+             shared_layout* layout_b);
+  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
   // accessor
   int fpw(size_t k) { return fpw_.at(k); }
   int wpt(size_t k) { return wpt_.at(k); }
+  int spw(size_t k) { return spw_.at(k); }
+  int spt(size_t k) { return spt_.at(k); }
+  int rep(size_t k) { return rep_.at(k); }
 private:
   std::vector<int> fpw_;
+  std::vector<int> spw_;
   std::vector<int> wpt_;
+  std::vector<int> spt_;
+  std::vector<int> rep_;
 };
 
 struct scanline_layout: public data_layout {
@@ -138,7 +147,7 @@ private:
   static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
 public:
-  shared_layout(const data_layout *arg,
+  shared_layout(data_layout *arg,
                 const std::vector<int>& axes,
                 const std::vector<unsigned>& shapes,
                 const std::vector<ir::value *> &values_,
@@ -149,11 +158,22 @@ public:
   size_t get_size() { return size_; }
   ir::type* get_type() { return ty_; }
   double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
+  size_t get_num_per_phase() { return num_per_phase_; }
+  ir::value* hmma_dot_a() { return hmma_dot_a_; }
+  ir::value* hmma_dot_b() { return hmma_dot_b_; }
+  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
+  int get_mma_vec() { return mma_vec_;}
+  data_layout* get_arg_layout() { return arg_layout_; }
 private:
   size_t size_;
   ir::type *ty_;
   std::shared_ptr<double_buffer_info_t> double_buffer_;
+  size_t num_per_phase_;
+  ir::value* hmma_dot_a_;
+  ir::value* hmma_dot_b_;
+  data_layout* arg_layout_;
+  int mma_vec_;
 };

View File

@@ -0,0 +1,43 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H

#include <map>

namespace triton{

namespace ir{
class module;
}

namespace codegen{
class target;

namespace analysis{

class layouts;
class data_layout;

class swizzle {
public:
  // constructor
  swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
  // accessors
  int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
  int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
  int get_vec(data_layout* layout) { return vec_.at(layout); }
  // run
  void run(ir::module &mod);
private:
  layouts* layouts_;
  target* tgt_;
  std::map<data_layout*, int> per_phase_;
  std::map<data_layout*, int> max_phase_;
  std::map<data_layout*, int> vec_;
};

}
}
}

#endif
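The header only exposes the per-layout quantities; the indexing math lives in the pass itself. As an illustration only, the classic XOR pattern these quantities support looks like the following (hypothetical helper; the exact formula used by generator.cc is not shown in this excerpt):

// Sketch of XOR-based swizzling: successive groups of rows rotate which
// vector of columns they touch, so a column of accesses hits distinct
// shared-memory banks without any padding.
int swizzled_col(int row, int col, int per_phase, int max_phase, int vec) {
  int phase = (row / per_phase) % max_phase;  // rows within a phase share a rotation
  return ((col / vec) ^ phase) * vec + (col % vec);
}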

View File

@@ -5,13 +5,14 @@
#include "triton/ir/visitor.h" #include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/codegen/selection/machine_value.h"
#include <functional> #include <functional>
// forward // forward
namespace llvm{ namespace llvm{
class Type; class Type;
class Value; class Value;
class BasicBlock;
class Attribute;
class Instruction; class Instruction;
class Constant; class Constant;
class LLVMContext; class LLVMContext;
@@ -25,6 +26,13 @@ namespace llvm{
 }
 
 namespace triton{
+
+namespace ir{
+class attribute;
+class load_inst;
+class store_inst;
+}
+
 namespace codegen{
 
 // forward
@@ -36,6 +44,7 @@ class allocation;
 class cts;
 class axes;
 class layouts;
+class swizzle;
 }
 
 // typedef
 typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -43,17 +52,14 @@ typedef llvm::IRBuilder<llvm::ConstantFolder,
 typedef llvm::LLVMContext LLVMContext;
 typedef llvm::Type Type;
 typedef llvm::Value Value;
+typedef llvm::Attribute Attribute;
+typedef llvm::BasicBlock BasicBlock;
 typedef llvm::Module Module;
 typedef llvm::Instruction Instruction;
 typedef llvm::Constant Constant;
 typedef llvm::ArrayType ArrayType;
 typedef llvm::Function Function;
 typedef std::vector<Value*> indices_t;
-// forward
-class machine_data_layout;
-class tile;
-class shared_tile;
-class distributed_tile;
 class target;
 }
@@ -62,110 +68,129 @@ class target;
 namespace triton{
 namespace codegen{
 
+struct distributed_axis {
+  int contiguous;
+  std::vector<Value*> values;
+  Value* thread_id;
+};
+
 class generator: public ir::visitor, public analysis::layout_visitor {
 private:
-  void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
-  Value* get_value(ir::value *x, const indices_t& idx);
-  void set_value(ir::value *x, const indices_t& idx, Value* v);
-  void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
-  void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
-  void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
-                       Type *c_ty, Function *f_mul_add);
+  void init_idx(ir::value *x);
+  Instruction* add_barrier();
+  Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
   void finalize_shared_layout(analysis::shared_layout*);
   void finalize_function(ir::function*);
   void finalize_phi_node(ir::phi_node*);
 
+private:
+  Type *cvt(ir::type *ty);
+  llvm::Attribute cvt(ir::attribute attr);
+
 public:
   generator(analysis::axes *a_axes,
             analysis::layouts *layouts,
             analysis::align *alignment,
             analysis::allocation *alloc,
+            analysis::swizzle *swizzle,
             target *tgt,
             unsigned num_warps);
 
   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
   void visit_binary_operator(ir::binary_operator*);
   void visit_getelementptr_inst(ir::getelementptr_inst*);
   void visit_icmp_inst(ir::icmp_inst*);
   void visit_fcmp_inst(ir::fcmp_inst*);
   void visit_cast_inst(ir::cast_inst*);
   void visit_return_inst(ir::return_inst*);
   void visit_cond_branch_inst(ir::cond_branch_inst*);
   void visit_uncond_branch_inst(ir::uncond_branch_inst*);
+  void visit_load_inst(ir::load_inst*);
   void visit_unmasked_load_inst(ir::unmasked_load_inst*);
   void visit_masked_load_inst(ir::masked_load_inst*);
+  void visit_store_inst(ir::store_inst*);
   void visit_unmasked_store_inst(ir::unmasked_store_inst*);
   void visit_masked_store_inst(ir::masked_store_inst*);
   void visit_reshape_inst(ir::reshape_inst*);
   void visit_splat_inst(ir::splat_inst*);
   void visit_broadcast_inst(ir::broadcast_inst*);
   void visit_downcast_inst(ir::downcast_inst*);
   void visit_exp_inst(ir::exp_inst*);
   void visit_log_inst(ir::log_inst*);
   void visit_get_program_id_inst(ir::get_program_id_inst*);
   void visit_get_num_program_inst(ir::get_num_program_inst*);
   void visit_atomic_cas_inst(ir::atomic_cas_inst*);
   void visit_atomic_exch_inst(ir::atomic_exch_inst*);
   void visit_atomic_add_inst(ir::atomic_add_inst*);
+  void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
+  void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
+  void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
   void visit_dot_inst(ir::dot_inst*);
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
+  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
+  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
   void visit_recoalesce_inst(ir::recoalesce_inst*);
+  void visit_masked_load_async_inst(ir::masked_load_async_inst*);
   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
   void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
   void visit_barrier_inst(ir::barrier_inst*);
+  void visit_async_wait_inst(ir::async_wait_inst*);
   void visit_make_range_dyn(ir::make_range_dyn*);
   void visit_make_range(ir::make_range*);
   void visit_make_range_sta(ir::make_range_sta*);
   void visit_undef_value(ir::undef_value*);
   void visit_constant_int(ir::constant_int*);
   void visit_constant_fp(ir::constant_fp*);
   void visit_alloc_const(ir::alloc_const*);
   void visit_function(ir::function*);
   void visit_basic_block(ir::basic_block*);
   void visit_argument(ir::argument*);
-  void visit(ir::module &, llvm::Module &);
 
-  void visit_layout_hmma_884(analysis::mma884_layout*);
+  // layouts
+  void visit_layout_mma(analysis::mma_layout*);
   void visit_layout_scanline(analysis::scanline_layout*);
   void visit_layout_shared(analysis::shared_layout*);
 
+  void visit(ir::module &, llvm::Module &);
+
 private:
   LLVMContext *ctx_;
   Builder* builder_;
   Module *mod_;
 
-  std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
   analysis::axes *a_axes_;
+  analysis::swizzle *swizzle_;
   std::map<unsigned, distributed_axis> axes_;
-  std::map<ir::value *, Value *> vmap_;
-  std::map<ir::value *, tile *> tmap_;
   target *tgt_;
   analysis::layouts *layouts_;
   analysis::align *alignment_;
   analysis::allocation *alloc_;
-  Value *sh_mem_ptr_;
+  Value *shmem_;
   unsigned num_warps_;
 
   std::set<ir::value*> seen_;
 
+  std::map<analysis::data_layout*, Value*> offset_a_m_;
+  std::map<analysis::data_layout*, Value*> offset_a_k_;
+  std::map<analysis::data_layout*, Value*> offset_b_k_;
+  std::map<analysis::data_layout*, Value*> offset_b_n_;
+  std::map<analysis::data_layout*, Value*> shared_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_next_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_off_;
+  std::map<ir::value*, Value*> shmems_;
+  std::map<ir::value*, Value*> shoffs_;
+  std::map<ir::value*, std::vector<indices_t>> idxs_;
+  std::map<ir::value*, std::map<indices_t, Value*>> vals_;
+  std::map<ir::value*, BasicBlock *> bbs_;
+  std::map<ir::value*, std::vector<int>> ords_;
 };
 
 }
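For reference, shared_off presumably linearizes a tile index according to the layout's dimension order; a standalone sketch of that computation (an assumption mirroring the usual linearization, not code from this commit):

#include <vector>

// order[0] is taken as the fastest-varying dimension, so its stride is 1
// and strides grow multiplicatively along the order.
int linear_shared_off(const std::vector<unsigned> &shapes,
                      const std::vector<int> &order,
                      const std::vector<int> &idx) {
  int off = 0, stride = 1;
  for(int k : order) {
    off += idx[k] * stride;
    stride *= shapes[k];
  }
  return off;
}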

View File

@@ -1,138 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
#define _TRITON_SELECTION_MACHINE_LAYOUT_H_
#include <map>
#include "triton/codegen/analysis/layout.h"
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class value;
}
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
}
}
namespace triton{
namespace codegen{
class machine_data_layout {
public:
virtual tile* create(ir::value *v) = 0;
};
class machine_shared_layout: public machine_data_layout {
public:
machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
analysis::shared_layout* layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::allocation* alloc_;
Value *&sh_mem_ptr_;
analysis::shared_layout* layout_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
Value *offset_;
Value *ptr_;
Value *pre_ptr_;
Value *next_ptr_;
};
class machine_distributed_layout: public machine_data_layout {
public:
machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout* layout);
tile* create(ir::value *v);
Module *mod_;
Builder *builder_;
target *tgt_;
analysis::axes *a_axes_;
std::map<unsigned, distributed_axis>& axes_;
analysis::data_layout* layout_;
};
class machine_mma884_layout: public machine_distributed_layout {
public:
machine_mma884_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout);
Value *offset_a_i_, *offset_a_k_;
Value *offset_b_j_, *offset_b_k_;
unsigned pack_size_0_;
unsigned pack_size_1_;
unsigned num_packs_0_;
unsigned num_packs_1_;
};
class machine_scanline_layout: public machine_distributed_layout {
public:
machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::scanline_layout* layout);
};
}
}
#endif

View File

@@ -1,152 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
#define _TRITON_SELECTION_MACHINE_VALUE_H_
#include <vector>
#include <map>
#include <functional>
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}
class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
typedef std::vector<Value*> indices_t;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class tile {
protected:
typedef std::vector<unsigned> shapes_t;
public:
tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
virtual void set_value(indices_t idx, Value *v) = 0;
virtual Value* get_value(indices_t idx) = 0;
Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
Type *ty_;
shapes_t shapes_;
};
class shared_tile: public tile {
private:
void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
public:
shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
void set_vector_size(unsigned vector_size);
void set_return_mode(bool return_vector);
void set_value(indices_t, Value *);
Value* get_ptr_to(indices_t idx);
Value* get_value(indices_t idx);
Value* get_pointer() { return ptr_; }
Value* get_offset() { return offset_; }
const std::vector<int>& get_perm() { return perm_; }
const std::vector<int>& get_order() { return order_; }
static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
private:
Value *ptr_;
bool return_vector_;
Builder &builder_;
Value *offset_;
std::map<indices_t, Value*> ptr_cache_;
unsigned vector_size_;
std::vector<int> order_;
std::vector<int> perm_;
};
// Distribtued tile
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
typedef std::vector<indices_t> ordered_indices_vec_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::map<indices_t, Value*> values_map_t;
private:
void init_indices();
public:
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
void set_value(indices_t idx, Value *v);
Value* get_value(indices_t idx);
const std::vector<int>& get_order() { return order_; }
unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
private:
axes_t axes_;
std::vector<int> order_;
indices_map_t indices_;
values_map_t values_;
ordered_indices_vec_t ordered_indices_;
Builder &builder_;
};
}
}
#endif

View File

@@ -35,6 +35,8 @@ namespace codegen{
 namespace triton{
 namespace codegen{
 
+class nvidia_cu_target;
+
 class target {
 public:
   target(bool is_gpu): is_gpu_(is_gpu){}
@@ -47,6 +49,7 @@ public:
   virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
   virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
   virtual unsigned guaranteed_alignment() = 0;
+  nvidia_cu_target* as_nvidia();
   bool is_gpu() const;
 
 private:
@@ -68,7 +71,7 @@ public:
 class nvidia_cu_target: public target {
 public:
-  nvidia_cu_target(): target(true){}
+  nvidia_cu_target(int sm): target(true), sm_(sm){}
   void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
   Instruction* add_barrier(Module *module, Builder& builder);
   Instruction* add_memfence(Module *module, Builder& builder);
@@ -76,7 +79,11 @@ public:
   Value* get_local_id(Module *module, Builder& builder, unsigned ax);
   Value* get_block_id(Module *module, Builder& builder, unsigned ax);
   Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
+  int sm() { return sm_; }
   unsigned guaranteed_alignment() { return 16; }
+private:
+  int sm_;
 };
 
 class cpu_target: public target {
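A plausible use of the new sm() accessor, sketched under assumed dispatch thresholds (the real logic lives in generator.cc, not shown here): sm >= 80 (A100) takes the new mma.16816 path, Volta-class parts keep mma.884, and everything else falls back to FMA.

#include "triton/codegen/target.h"

// Sketch only: how a lowering pass might branch on the SM version.
void lower_dot(triton::codegen::target *tgt) {
  triton::codegen::nvidia_cu_target *nv = tgt->as_nvidia();  // assumed null on CPU targets
  if(nv && nv->sm() >= 80)      { /* visit_mma16816: Ampere tensor cores */ }
  else if(nv && nv->sm() >= 70) { /* visit_mma884: Volta tensor cores */ }
  else                          { /* visit_fmadot: scalar FMA fallback */ }
}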

View File

@@ -11,14 +11,22 @@ namespace ir {
 class value;
 class phi_node;
 class instruction;
+class builder;
 }
 
 namespace codegen{
 namespace transform{
 
 class cts {
+private:
+  void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
 public:
+  cts(bool use_async = false): use_async_(use_async) {}
   void run(ir::module &mod);
+private:
+  bool use_async_;
 };
 
 }

View File

@@ -1,6 +1,8 @@
 #ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
 #define TDL_INCLUDE_CODEGEN_BARRIERS_H
 
+#include <vector>
+
 namespace triton {
 namespace ir {
@@ -31,14 +33,14 @@ private:
 private:
   interval_vec_t join(const std::vector<interval_vec_t>& intervals);
-  void insert_barrier(ir::instruction *instr, ir::builder &builder);
+  void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
   bool intersect(const interval_vec_t &X, interval_t x);
   bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
   void add_reference(ir::value *v, interval_vec_t &res);
   void get_read_intervals(ir::instruction *i, interval_vec_t &res);
   void get_written_intervals(ir::instruction *i, interval_vec_t &res);
   std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
-                                                     std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
+                                                     std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
 public:
   membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -1,6 +1,7 @@
 #ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
 #define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
 
+#include "triton/codegen/target.h"
 
 namespace triton {
@@ -27,12 +28,16 @@ private:
   bool rewrite_mult(ir::instruction *value, ir::builder& builder);
   bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
+  bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
 
 private:
 
 public:
-  peephole() {}
+  peephole(target* tgt): tgt_(tgt) {}
   void run(ir::module &mod);
+
+private:
+  target* tgt_;
 };

View File

@@ -0,0 +1,26 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H

namespace triton {

// forward declaration
namespace ir {
class module;
}

namespace codegen{
namespace transform{

class reorder {
public:
  void run(ir::module& module);
};

}
}
}

#endif

View File

@@ -39,43 +39,23 @@ public:
 // CUDA device
 class cu_device: public device {
-public:
-  //Supported architectures
-  enum class Architecture{
-    //NVidia
-    SM_2_0,
-    SM_2_1,
-    SM_3_0,
-    SM_3_5,
-    SM_3_7,
-    SM_5_0,
-    SM_5_2,
-    SM_6_0,
-    SM_6_1,
-    SM_7_0,
-    UNKNOWN
-  };
 private:
   //Metaprogramming elper to get cuda info from attribute
   template<CUdevice_attribute attr>
   int cuGetInfo() const;
-  inline Architecture nv_arch(std::pair<unsigned int, unsigned int> sm) const;
   inline nvmlDevice_t nvml_device() const;
 
 public:
   cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
-  // Accessors
-  Architecture architecture() const;
   // Informations
   std::string infos() const;
   size_t address_bits() const;
   std::vector<size_t> max_block_dim() const;
   size_t warp_size() const;
   // Compute Capability
-  void interpret_as(std::pair<size_t, size_t> cc);
-  std::pair<size_t, size_t> compute_capability() const;
+  void interpret_as(int cc);
+  int compute_capability() const;
   // Identifier
   std::string name() const;
   std::string pci_bus_id() const;
@@ -91,7 +71,7 @@ public:
   std::unique_ptr<codegen::target> make_target() const;
 
 private:
-  std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
+  std::shared_ptr<int> interpreted_as_;
 };
 
 }

View File

@@ -19,18 +19,18 @@ namespace triton
 namespace nvrtc
 {
-#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
-ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
-ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
-ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
-ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
-ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
-#undef ISAAC_CREATE_NVRTC_EXCEPTION
+#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
+TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
+TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
+TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
+TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
+TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
+#undef TRITON_CREATE_NVRTC_EXCEPTION
 }
@@ -38,107 +38,107 @@ namespace triton
 {
 class base: public std::exception{};
-#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
-ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
-ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
-ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
-ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
-ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
-ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
-ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
-ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
-ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
-ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
-ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
-ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
-ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
-ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
-ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found");
-ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
-ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
-ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
-ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
-ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
-ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
-ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
-ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
-ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
-ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
-ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
-ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
-ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
-#undef ISAAC_CREATE_CUDA_EXCEPTION
+#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
+TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
+TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
+TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
+TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
+TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
+TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
+TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
+TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
+TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
+TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
+TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
+TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
+TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
+TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
+TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
+TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
+TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
+TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
+TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
+TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
+TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
+TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
+TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
+TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
+TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
+TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
+TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
+TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
+TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
+TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
+TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
+TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
+TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
+TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
+TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
+#undef TRITON_CREATE_CUDA_EXCEPTION
 }
 
 namespace cublas
 {
 class base: public std::exception{};
-#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
-ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
-ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
-ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
-ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
-#undef ISAAC_CREATE_CUBLAS_EXCEPTION
+#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
+TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
+TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
+TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
+TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
+TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
+TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
+TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
+#undef TRITON_CREATE_CUBLAS_EXCEPTION
 }
 
 namespace cudnn
 {
-#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
-ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
-ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
-ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
-ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
-ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
-ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
-ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
+#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
+TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
+TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
+TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
+TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
+TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
+TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
+TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
 }
 }

View File

@@ -44,6 +44,13 @@ public:
                        const std::string &features,
                        file_type_t file_type);
   virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
+  std::string llir() const { return llir_; }
+  int spilled() const { return spilled_; }
+private:
+  std::string llir_;
+protected:
+  int spilled_;
 };
// CPU // CPU
@@ -59,12 +66,12 @@ class cu_module: public module {
 public:
   cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
-  cu_module(const std::string& source);
+  cu_module(driver::device* device, const std::string& source);
   std::unique_ptr<buffer> symbol(const char * name) const;
-  const std::string& source() const { return source_; }
+  const std::string& ptx() const { return ptx_; }
 private:
-  std::string source_;
+  std::string ptx_;
 };

View File

@@ -146,8 +146,10 @@ public:
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
+  value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
   value *create_copy_from_shared(value *arg, const std::string &name = "");
   value *create_barrier(const std::string &name = "");
+  value *create_async_wait();
 
 private:
   context &ctx_;
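Taken together with create_async_wait, the intended emission pattern is presumably along these lines (a sketch; the builder and operand values are assumed to be set up by the caller):

#include "triton/ir/builder.h"

// Emit an asynchronous masked copy into shared memory, then wait for it
// and synchronize the CTA before the data is consumed.
void emit_async_copy(triton::ir::builder &b, triton::ir::value *ptr,
                     triton::ir::value *mask, triton::ir::value *other) {
  b.create_masked_load_async(ptr, mask, other);  // lowered to cp.async on sm_80 (assumed)
  b.create_async_wait();
  b.create_barrier();
}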

View File

@@ -7,7 +7,7 @@ namespace triton{
 namespace ir{
 
-enum binary_op_t {
+enum binary_op_t: unsigned int{
   Add,
   FAdd,
   Sub,
@@ -28,7 +28,7 @@ enum binary_op_t {
   Xor
 };
 
-enum cast_op_t {
+enum cast_op_t: unsigned int {
   Trunc,
   ZExt,
   SExt,
@@ -44,7 +44,7 @@ enum cast_op_t {
   AddrSpaceCast
 };
 
-enum cmp_pred_t {
+enum cmp_pred_t: unsigned int {
   FIRST_FCMP_PREDICATE,
   FCMP_FALSE,
   FCMP_OEQ,
@@ -113,6 +113,7 @@ enum value_id_t: unsigned {
   // io
   INST_UNMASKED_LOAD,
   INST_MASKED_LOAD,
+  INST_MASKED_LOAD_ASYNC,
   INST_UNMASKED_STORE,
   INST_MASKED_STORE,
   // retile
@@ -139,6 +140,7 @@ enum value_id_t: unsigned {
   INST_COPY_FROM_SHARED,
   INST_RECOALESCE,
   INST_BARRIER,
+  INST_ASYNC_WAIT,
   INST_MAKE_RANGE_DYN,
   INST_MAKE_RANGE_STA,
   INST_MAKE_RANGE

View File

@@ -72,6 +72,7 @@ public:
     case noalias: return ".noalias";
     case aligned: return ".aligned(" + std::to_string(value_) + ")";
     case multiple_of: return ".readonly";
+    case retune: return ".retunr";
     default: break;
   }
   assert(false);

View File

@@ -64,9 +64,10 @@ public:
   // cloning
   ir::instruction* clone() {
     ir::instruction* res = clone_impl();
-    for(auto it = op_begin(); it != op_end(); it++)
-      (*it)->add_use(res);
+//    for(auto it = op_begin(); it != op_end(); it++)
+//      (*it)->add_use(res);
     res->parent_ = nullptr;
+    res->users_.clear();
     return res;
   }
   // instruction id
@@ -431,6 +432,25 @@ public:
   _TRITON_DEFINE_ACCEPT(masked_load_inst)
 };
 
+// masked load async
+class masked_load_async_inst: public load_inst {
+private:
+  std::string repr_impl() const { return "masked_load_async_async"; }
+  masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                         const std::string &name, instruction *next);
+public:
+  // accessors
+  value *get_mask_operand() { return get_operand(1); }
+  value *get_false_value_operand() { return get_operand(2); }
+  // factory method
+  static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
+                                        const std::string &name = "",
+                                        instruction *next = nullptr);
+  _TRITON_DEFINE_CLONE(masked_load_async_inst)
+  _TRITON_DEFINE_ACCEPT(masked_load_async_inst)
+};
+
 class atomic_add_inst: public io_inst {
 private:
   atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
@@ -757,6 +777,7 @@ public:
   _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
 };
 
 class recoalesce_inst: public unary_inst{
 private:
   using unary_inst::unary_inst;
@@ -780,6 +801,18 @@ public:
                               instruction *next = nullptr);
 };
 
+class async_wait_inst: public instruction{
+private:
+  async_wait_inst(context &ctx, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "async_wait"; }
+  _TRITON_DEFINE_CLONE(async_wait_inst)
+  _TRITON_DEFINE_ACCEPT(async_wait_inst)
+public:
+  static async_wait_inst* create(context &ctx, const std::string &name = "",
+                                 instruction *next = nullptr);
+};
+
 // On NVIDIA, implementation is such that
 // constant_range = nv_dynamic_program_idx + nv_static_program_idx
 // so as to enable re-association on nv_static_program_idx which is constant

View File

@@ -65,7 +65,9 @@ class select_inst;
 class recoalesce_inst;
 class copy_to_shared_inst;
 class copy_from_shared_inst;
+class masked_load_async_inst;
 class barrier_inst;
+class async_wait_inst;
 class make_range_dyn;
 class make_range;
@@ -139,7 +141,9 @@ public:
   virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
   virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
   virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
+  virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
   virtual void visit_barrier_inst(barrier_inst*) = 0;
+  virtual void visit_async_wait_inst(async_wait_inst*) = 0;
   virtual void visit_make_range_dyn(make_range_dyn*) = 0;
   virtual void visit_make_range(make_range*) = 0;

View File

@@ -0,0 +1,34 @@
#pragma once

#ifndef _TRITON_RUNTIME_ERROR_H_
#define _TRITON_RUNTIME_ERROR_H_

#include <exception>
#include <string>

namespace triton {
namespace runtime{
namespace exception {

class base: public std::exception {};

#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };

TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")

class no_valid_configuration: public exception::base {
public:
  no_valid_configuration(const std::string& err): err_(err) { }
  const char * what() const throw(){ return err_.c_str(); }
private:
  std::string err_;
};

#undef TRITON_CREATE_RUNTIME_EXCEPTION

}
}
}

#endif
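A minimal sketch of how the new module accessors and these exceptions could be glued together to implement the "compilation error for kernels that spill" feature (assumed policy and header paths; the real check is in the runtime, outside this excerpt):

#include "triton/driver/module.h"
#include "triton/runtime/error.h"

// Hypothetical check: reject any binary whose register-spill count is non-zero.
void check_spills(const triton::driver::module &bin) {
  if(bin.spilled() > 0)
    throw triton::runtime::exception::out_of_registers();
}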

View File

@@ -6,6 +6,7 @@
 #include <map>
 #include <vector>
 #include <string>
+#include <sstream>
 #include <memory>
 #include <functional>
 #include <set>
@@ -13,6 +14,7 @@
#include "triton/ir/context.h" #include "triton/ir/context.h"
#include "triton/codegen/target.h" #include "triton/codegen/target.h"
#include "triton/runtime/arg.h" #include "triton/runtime/arg.h"
#include "triton/runtime/error.h"
namespace llvm { namespace llvm {
class Module; class Module;
@@ -56,33 +58,43 @@ template<typename T> inline T convert(const std::string& name);
 template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
 template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
 
+template<class T>
+void add_arg(std::stringstream& ss, T arg) {
+  ss.write((char*)&arg, sizeof(T));
+}
+
+enum asm_mode_t {
+  ASM_LLIR,
+  ASM_NV_PTX,
+  ASM_NV_SASS
+};
+
+struct options_space_t {
+  typedef std::pair<std::string, std::vector<std::string>> define_t;
+  std::vector<define_t> defines;
+  std::vector<int> num_warps;
+  std::vector<int> recompile_key;
+};
+
+struct options_t {
+  template<class T>
+  T D(const std::string& name) const {
+    return convert<T>(defines.at(name));
+  }
+  bool operator<(const options_t& other) const {
+    return std::make_pair(defines, num_warps) <
+           std::make_pair(other.defines, other.num_warps);
+  }
+  std::string to_str() const;
+  std::map<std::string, std::string> defines;
+  size_t num_warps;
+};
+
 class function {
 public:
-  struct options_space_t {
-    typedef std::pair<std::string, std::vector<std::string>> define_t;
-    std::vector<define_t> defines;
-    std::vector<int> num_warps;
-    std::vector<int> recompile_key;
-  };
-  struct options_t {
-    template<class T>
-    T D(const std::string& name) const {
-      return convert<T>(defines.at(name));
-    }
-    bool operator<(const options_t& other) const {
-      return std::make_pair(defines, num_warps) <
-             std::make_pair(other.defines, other.num_warps);
-    }
-    std::string to_str() const;
-    std::map<std::string, std::string> defines;
-    size_t num_warps;
-  };
   typedef std::function<grid_t(const options_t&)> grid_fn_ty;
 
 private:
   class caller {
   public:
@@ -135,7 +147,7 @@ public:
void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device); void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device); void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
void set_cst(const char* name, void* data, size_t n_bytes); void set_cst(const char* name, void* data, size_t n_bytes);
std::string ptx(driver::device *device, const options_t& opt); std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt);
private: private:
std::map<std::string, std::vector<char>> cst_; std::map<std::string, std::vector<char>> cst_;
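Hypothetical call sites for the renamed accessor (fn, device and opt are assumed to be in scope); the old ptx() entry point corresponds to ASM_NV_PTX:

std::string llir = fn.get_asm(ASM_LLIR,   device, opt); // LLVM IR
std::string ptx  = fn.get_asm(ASM_NV_PTX, device, opt); // PTX
std::string sass = fn.get_asm(ASM_NV_SASS, device, opt); // SASS disassembly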

View File

@@ -33,25 +33,20 @@ private:
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false) inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
{ {
// const driver::device * device = stream->context()->device(); // const driver::device * device = stream->context()->device();
size_t warmup = 10;
size_t repeat = 50;
timer tmr; timer tmr;
std::vector<size_t> times; std::vector<size_t> times;
double total_time = 0; double total_time = 0;
op(); for(size_t i = 0; i < warmup; i++)
op();
stream->synchronize(); stream->synchronize();
tmr.start(); tmr.start();
for(size_t i = 0; i < 10; i++){ for(size_t i = 0; i < repeat; i++){
// while(total_time*1e-9 < 1e-2){
// float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(normalize)
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
op(); op();
// times.push_back(norm*tmr.get().count());
// total_time+=times.back();
} }
stream->synchronize(); stream->synchronize();
return (float)tmr.get().count() / 10; return (float)tmr.get().count() / repeat;
// return *std::min_element(times.begin(), times.end()); // return *std::min_element(times.begin(), times.end());
} }
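A hedged usage sketch (kernel, args and grid are placeholders): bench now performs 10 warm-up calls, then times 50 calls and returns the mean.

double t = bench([&](){ kernel(args, grid, stream); }, stream);
// t = mean per-call time over the 50 timed runs, expressed in the
// timer's tick units (nanoseconds for Triton's timer).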

View File

@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
graph_.add_edge({dot, d}, {D, d}); graph_.add_edge({dot, d}, {D, d});
} }
void axes::update_graph_elementwise(ir::instruction *i) { void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
if(i->get_num_operands() == 0) if(i->get_num_operands() == 0)
return; return;
ir::value *op = i->get_operand(0); ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
for(unsigned d = 0; d < rank; d++) for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops()) for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){ for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty()) if(connect_ret && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d}); graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d}); graph_.add_edge({opx, d}, {opy, d});
} }
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i); case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);; case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i); case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i); default: return update_graph_elementwise(i);

View File

@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u); auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v) if(i && is_hmma_c(i) && i->get_operand(n) == v)
result = v; result = i;
} }
} }
@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
} }
} }
size_t data_layout::find_axis(int to_find) const { int data_layout::find_axis(int to_find) const {
auto it = std::find(axes_.begin(), axes_.end(), to_find); auto it = std::find(axes_.begin(), axes_.end(), to_find);
if(it == axes_.end())
return -1;
return std::distance(axes_.begin(), it); return std::distance(axes_.begin(), it);
} }
@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
* MMA Layout * * MMA Layout *
* -------------------------------- */ * -------------------------------- */
mma884_layout::mma884_layout(size_t num_warps, mma_layout::mma_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) { analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
/* fragments per warp */ /* fragments per warp */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
fpw_ = {1, 1, 1}; if(tgt->as_nvidia()->sm() < 80){
std::vector<int> fpw_nm1; fpw_ = {1, 1, 1};
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4); std::vector<int> fpw_nm1;
do { unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
fpw_nm1 = fpw_; do {
if(fpw_[0]*fpw_[1] < num_fragments) fpw_nm1 = fpw_;
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); if(fpw_[0]*fpw_[1] < num_fragments)
if(fpw_[0]*fpw_[1] < num_fragments) fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); if(fpw_[0]*fpw_[1] < num_fragments)
}while(fpw_nm1 != fpw_); fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
}while(fpw_nm1 != fpw_);
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
bool is_b_row = ord_b[0] != 0;
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
}
else{
fpw_ = {1, 1, 1};
spw_ = {16, 8, 1};
rep_ = {2, 2, 1};
}
/* warps per tile */ /* warps per tile */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
do{ do{
wpt_nm1 = wpt_; wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8)); wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8)); wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_); }while(wpt_nm1 != wpt_);
/* sanity check */ /* shape per block */
unsigned effective_num_warps = 1; spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
for(size_t d = 0; d < shape.size(); d++)
effective_num_warps *= wpt_[d];
// if(num_warps != effective_num_warps)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr; ir::value *ptr = nullptr;
for(ir::value *v: values) for(ir::value *v: values)
for(ir::user *usr: v->get_users()) for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::store_inst*>(usr)) if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand(); ptr = st->get_pointer_operand();
unsigned i = order_[0]; unsigned i = order_[0];
int contiguous = 4; int contiguous = 1;
if(ptr) if(ptr){
contiguous = std::min<int>(align->contiguous(ptr)[i], 4); int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i])); nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i]; num_threads = num_threads / mts_[i];
} }
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shape_.size(); d++)
effective_num_threads *= mts_[d];
// std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
// if(num_warps * 32 != effective_num_threads)
// throw std::runtime_error("cannot create a kernel with this amount of warps");
} }
@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
ir::value *value_1 = phi->get_incoming_value(1); ir::value *value_1 = phi->get_incoming_value(1);
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0); ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1); ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
if(!i_0 || !i_1 || if(!(i_0 && !i_1) &&
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) || !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) ) !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
return; return;
if(is_latch_1) if(is_latch_1)
res.reset(new double_buffer_info_t{value_0, value_1, phi}); res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
} }
shared_layout::shared_layout(const data_layout *arg, shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
size_ = 0; size_ = 0;
arg_layout_ = arg;
// double-buffering // double-buffering
for(ir::value *v: values) for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
extract_hmma_dot_use(v, hmma_dot_a, 0); extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1); extract_hmma_dot_use(v, hmma_dot_b, 1);
} }
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// non-mma ordering
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
for(size_t s = 2; s < get_rank(); s++){
col.push_back(s);
row.push_back(s);
}
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
order_ = is_trans(dot_a) ? row : col;
else if(is_nonhmma_dot_b)
order_ = is_trans(dot_b) ? col : row;
// padding
size_t pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
pad = 24 - shape_[row ? 0 : 1] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
pad = 24 - shape_[row ? 1 : 0] % 32;
}
else if(order_ != arg_order) {
pad = 4;
}
shape_[order_[0]] += pad;
// size // size
size_ = ty_->get_primitive_size_in_bits() / 8; size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
} }
void layouts::create(size_t id, const std::vector<ir::value*>& values) { void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto cmp = [](ir::value* x, ir::value *y) { auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
const auto& axes = axes_->get(largest); const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes(); const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) { auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
return dynamic_cast<ir::copy_to_shared_inst*>(v); return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v);
}); });
// type // type
if(it_hmma_c != values.end()) if(it_hmma_c != values.end()){
layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_); ir::instruction *dot = (ir::instruction*)*it_hmma_c;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
}
else if(it_cts != values.end()){ else if(it_cts != values.end()){
ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts; ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg))); create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
} }
else else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
} }
void layouts::run(ir::module &mod) { void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
} }
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){ if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0); ir::value *val = recoalasce->get_operand(0);
mma884_layout* in_layout = get(val)->to_mma884(); mma_layout* in_layout = get(val)->to_mma();
scanline_layout* out_layout = get(i)->to_scanline(); scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout) if(!in_layout || !out_layout)
return; return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
shape[ld] = in_shape[ld]; shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++) for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld) if(k != ld)
shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k); shape[k] = in_layout->to_mma()->spt(k);
// create layout // create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id; tmp_[recoalasce] = id;
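A worked example (assumed values) of the new per-type contiguity cap in scanline_layout, contiguous = min(align->contiguous(ptr)[i], 128 / nbits):

// fp32 pointer: nbits = 32, cap = 128/32 = 4 elements (a 16-byte access)
// fp16 pointer: nbits = 16, cap = 128/16 = 8 elements (also 16 bytes)
// The old hard-coded cap of 4 under-vectorized fp16 accesses (only 8 bytes).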

View File

@@ -0,0 +1,54 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>
namespace triton{
namespace codegen{
namespace analysis{
void swizzle::run(ir::module &) {
per_phase_.clear();
max_phase_.clear();
for(auto &x: layouts_->get_all()){
shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
if(!layout)
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
continue;
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else{
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
}
}
}
}
}
}
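A worked example (assumed inputs) of the swizzle parameters above, for an fp16 operand (dtsize = 2) fed by a scanline layout with mts(ord[0]) = 8 and nts(ord[0]) = 4:

// per_phase = max(128 / (8 * 4 * 2), 1) = 2            (both code paths)
// sm < 80:  max_phase = (ord[inner] == 1 ? 8 : 4) / 2, i.e. 4 or 2;
//           vec = 2 * rep of the consuming mma layout along that operand
// sm >= 80: max_phase = 8 / 2 = 4; vec = 8 (128 bits of fp16 per access)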

File diff suppressed because it is too large

View File

@@ -1,325 +0,0 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"
namespace triton{
namespace codegen{
using namespace llvm;
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
std::vector<Type*> param_tys;
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
return FunctionType::get(return_ty, param_tys, false);
}
// pointer
if(ty->is_pointer_ty()){
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
unsigned addr_space = ty->get_pointer_address_space();
return PointerType::get(elt_ty, addr_space);
}
// integer
if(ty->is_integer_ty()){
unsigned bitwidth = ty->get_integer_bitwidth();
return IntegerType::get(ctx, bitwidth);
}
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
default: break;
}
// unknown type
throw std::runtime_error("unknown conversion from ir::type to Type");
}
// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
size_t dim = shapes.size();
std::vector<Value*> result(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder.getInt32(shapes[order[k]]);
Value *rem = builder.CreateURem(trailing, dim_k);
trailing = builder.CreateUDiv(trailing, dim_k);
result[order[k]] = rem;
}
result[order[dim - 1]] = trailing;
return result;
}
inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
Value *&sh_mem_ptr, analysis::shared_layout *layout,
std::map<ir::value *, Value *>& vmap,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout_->get_double_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
if(parent->empty())
builder_->SetInsertPoint(parent);
else
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
// create pointers
ptr_ = builder_->CreatePHI(ptr_ty, 2);
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
builder_->SetInsertPoint(current);
}
else{
size_t offset = alloc_->offset(layout_);
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
}
}
tile* machine_shared_layout::create(ir::value *v) {
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
auto double_buffer = layout_->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = offset_;
// base pointer
Value *ptr = ptr_;
if(double_buffer && v == double_buffer->latch)
ptr = next_ptr_;
else if(double_buffer && v == double_buffer->first)
ptr = pre_ptr_;
// create tile
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
analysis::data_layout *layout)
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
}
tile *machine_distributed_layout::create(ir::value *v) {
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
const auto &shapes = v->get_type()->get_tile_shapes();
size_t rank = shapes.size();
std::vector<distributed_axis> axes(rank);
std::vector<int> order(rank);
// compute axes
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
}
else{
axes[d].contiguous = 1;
axes[d].values = {builder_->getInt32(0)};
}
}
// compute order
std::iota(order.begin(), order.end(), 0);
auto cmp = [&](int x, int y) {
unsigned axx = a_axes_->get(v, x);
unsigned axy = a_axes_->get(v, y);
size_t posx = layout_->find_axis(axx);
size_t posy = layout_->find_axis(axy);
if(posx < rank && posy < rank)
return layout_->get_order(posx) < layout_->get_order(posy);
return false;
};
std::sort(order.begin(), order.end(), cmp);
return new distributed_tile(ty, shapes, order, axes, *builder_);
}
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
target *tgt, analysis::axes *a_axes,
std::map<unsigned, distributed_axis>& axes,
analysis::mma884_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
const auto& shape = layout->get_shape();
if(shape.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shape.size() >= 3;
Value *_1 = builder_->getInt32(1);
Value *_2 = builder_->getInt32(2);
Value *_3 = builder_->getInt32(3);
Value *_4 = builder_->getInt32(4);
Value *_16 = builder_->getInt32(16);
// fragments per warp
unsigned fpw_0 = layout->fpw(0);
unsigned fpw_1 = layout->fpw(1);
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
// warps per tile
unsigned wpt_0 = layout->wpt(0);
unsigned wpt_1 = layout->wpt(1);
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
// mma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// mma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shape[0] / hmma_bts_0;
unsigned num_rep_1 = shape[1] / hmma_bts_1;
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
builder_->getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
// Quad pair offset
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
/* offsets */
// a offset
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
// b offsets
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
// c offsets
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
builder_->CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
for(unsigned pack = 0; pack < num_packs_0_; pack++)
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
}
// j indices
std::vector<Value*> idx_j;
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
/* axes */
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
target *tgt,
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
analysis::scanline_layout* layout)
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
Value *warp_size = builder_->getInt32(32);
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
auto order = layout->get_order();
const auto& shape = layout->get_shape();
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
// Delinearize
size_t dim = shape.size();
std::vector<Value*> thread_id(dim);
for(unsigned k = 0; k < dim - 1; k++){
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
thread_id[order[k]] = rem;
}
thread_id[order[dim - 1]] = full_thread_id;
// Create axes
for(unsigned k = 0; k < dim; k++) {
int nts = layout->nts(k);
int mts = layout->mts(k);
std::string str_k = std::to_string(k);
Value *contiguous_k = builder_->getInt32(nts);
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
unsigned per_block = nts * mts;
unsigned per_thread = nts * shape[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / nts * per_block + n % nts;
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
}
}
}
}

View File

@@ -1,214 +0,0 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"
namespace triton{
namespace codegen{
using namespace llvm;
/* Distributed Tile */
void distributed_tile::init_indices() {
std::vector<size_t> id(axes_.size(), 0);
// build
size_t k = 0;
while(true) {
indices_t current;
for(size_t d = 0; d < id.size(); d++)
current.push_back(axes_[d].values[id[d]]);
size_t sz = indices_.size();
indices_[current] = sz;
values_[current] = nullptr;
ordered_indices_.push_back(current);
id[order_[0]]++;
while(id[order_[k]] == axes_[order_[k]].values.size()){
if(k == id.size() - 1)
return;
id[order_[k++]] = 0;
id[order_[k]]++;
}
k = 0;
}
}
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
init_indices();
}
void distributed_tile::set_value(indices_t idx, Value *x) {
assert(x->getType() == ty_ && "cannot set a value of different type");
Value *&result = values_[idx];
assert(!result && "value cannot be set twice");
result = x;
}
Value* distributed_tile::get_value(indices_t idx) {
Value *result = values_.at(idx);
assert(result && "value has not been set");
return result;
}
unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx];
}
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
if(end < 0)
end = ordered_indices_.size() + end + 1;
for(unsigned i = start; i < end; i++)
fn(ordered_indices_[i]);
}
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
int rank = sizes.size();
int len = 1;
for(int s: sizes)
len *= s;
for(int i = 0; i < len; i++){
indices_t idx(rank);
int current = i;
for(int k = 0; k < rank; k++){
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
current = current / sizes[k];
}
fn(idx);
}
}
/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
if(dyn_cast<Constant>(arg)){
cst = arg;
non_cst = _0;
return;
}
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
non_cst = arg;
cst = _0;
return;
}
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
if(cst_lhs && cst_rhs){
cst = arg;
non_cst = _0;
}
else if(cst_lhs){
cst = cst_lhs;
non_cst = bin_op->getOperand(1);
}
else if(cst_rhs){
cst = cst_rhs;
non_cst = bin_op->getOperand(0);
}
else{
non_cst = arg;
cst = _0;
}
}
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
non_cst_idx.clear();
cst_idx.clear();
for(Value *idx: arg_idx){
Value *non_cst, *cst;
extract_constant(idx, non_cst, cst);
non_cst_idx.push_back(non_cst);
cst_idx.push_back(cst);
}
}
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
const std::vector<int>& perm, const std::vector<int>& order,
indices_t idx) {
// strides
std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
strides[order[0]] = builder.getInt32(1);
for(size_t i = 1; i < idx.size(); i++)
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
// result
Value *result = builder.getInt32(0);
for(size_t i = 0; i < idx.size(); i++)
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
return result;
}
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
return_vector_ = false;
if(perm_.empty()){
perm_.resize(shapes.size());
std::iota(perm_.begin(), perm_.end(), 0);
}
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
}
void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
unsigned vector_size = vector_size_;
Type *ty = ty_;
if(ty->isHalfTy() && (vector_size % 2 == 0)){
ty = IntegerType::get(ty->getContext(), 32);
vector_size = vector_size / 2;
}
if(base_ptr == nullptr){
// BasicBlock* store = builder_.GetInsertBlock();
// if(!non_cst_idx.empty())
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
return result;
}
}
}

View File

@@ -14,6 +14,12 @@ namespace triton{
namespace codegen{ namespace codegen{
// base // base
nvidia_cu_target* target::as_nvidia() {
return dynamic_cast<nvidia_cu_target*>(this);
}
bool target::is_gpu() const { bool target::is_gpu() const {
return is_gpu_; return is_gpu_;
} }
@@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m
Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) { Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier); Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
return builder.CreateCall(barrier, {}); return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
} }
Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
Intrinsic::amdgcn_workgroup_id_y, Intrinsic::amdgcn_workgroup_id_y,
Intrinsic::amdgcn_workgroup_id_z Intrinsic::amdgcn_workgroup_id_z
}; };
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]); Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
Value* group_id = builder.CreateCall(get_group_id, {});
return group_id; return group_id;
} }
@@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
Intrinsic::nvvm_read_ptx_sreg_ctaid_y, Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z Intrinsic::nvvm_read_ptx_sreg_ctaid_z
}; };
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]); Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id; return cta_id;
} }
@@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
Intrinsic::nvvm_read_ptx_sreg_nctaid_y, Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z Intrinsic::nvvm_read_ptx_sreg_nctaid_z
}; };
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]); return builder.CreateIntrinsic(ids[ax], {}, {});
return builder.CreateCall(get_nctaid, {});
} }
// CPU // CPU

View File

@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {
for(size_t id = 0; id < num_groups; id++) { for(size_t id = 0; id < num_groups; id++) {
if(!layout_->get(id)->to_mma884()) if(!layout_->get(id)->to_mma())
continue; continue;
// extract memory stores // extract memory stores
const auto& values = layout_->values_of(id); const auto& values = layout_->values_of(id);

View File

@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
return true; return true;
if(i->get_id() == ir::INST_COPY_TO_SHARED) if(i->get_id() == ir::INST_COPY_TO_SHARED)
return true; return true;
if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
return true;
return false; return false;
} }
// run pass on module // run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
auto *i = dynamic_cast<ir::instruction*>(x); auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction // not an instruction
if(!i) { if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
// copy // copy
builder.set_insert_point_after(i); builder.set_insert_point_after(i);
ir::value *copy; ir::value *copy;
if(to_shared) if(to_shared){
copy = builder.create_copy_to_shared(x); copy = builder.create_copy_to_shared(x);
}
else else
copy = builder.create_copy_from_shared(x); copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy); parent->replace_uses_of_with(x, copy);

View File

@@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
add_reference(i, res); add_reference(i, res);
} }
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) { void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) { if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming; std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){ for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
if(incoming.insert(inc_val).second){ if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent(); ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back()); builder.set_insert_point(block->get_inst_list().back());
builder.create_barrier(); if(type.first)
builder.create_async_wait();
if(type.second)
builder.create_barrier();
} }
} }
} }
@@ -85,8 +88,9 @@ std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block, membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to, const interval_vec_t &written_to,
const interval_vec_t &read_from, const interval_vec_t &read_from,
std::set<ir::instruction*>& insert_loc, std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
std::set<ir::value*>& safe_war) { std::set<ir::value*>& safe_war,
std::vector<ir::instruction*>& to_sync) {
ir::basic_block::inst_list_t instructions = block->get_inst_list(); ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to; interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from; interval_vec_t new_read_from = read_from;
@@ -95,6 +99,8 @@ std::pair<membar::interval_vec_t,
interval_vec_t read, written; interval_vec_t read, written;
get_read_intervals(i, read); get_read_intervals(i, read);
get_written_intervals(i, written); get_written_intervals(i, written);
if(written.size())
to_sync.push_back(i);
bool read_after_write = intersect(new_written_to, read); bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written); bool write_after_read = intersect(new_read_from, written);
// double buffering // double buffering
@@ -104,9 +110,14 @@ std::pair<membar::interval_vec_t,
} }
// record hazards // record hazards
if(read_after_write || write_after_read) { if(read_after_write || write_after_read) {
insert_loc.insert(i); auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
insert_loc.insert({i, {copy_async_wait, barrier}});
new_written_to.clear(); new_written_to.clear();
new_read_from.clear(); new_read_from.clear();
to_sync.clear();
} }
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to)); std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from)); std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
@@ -125,17 +136,17 @@ void membar::run(ir::module &mod) {
if(!layout || !layout->get_double_buffer()) if(!layout || !layout->get_double_buffer())
continue; continue;
for(ir::value *v: layout->get_values()) for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi) if(v != layout->get_double_buffer()->phi){
safe_war.insert(v); safe_war.insert(v);
}
} }
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn); std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to; std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from; std::map<ir::basic_block*, interval_vec_t> read_from;
std::set<ir::instruction*> insert_locs; std::vector<ir::instruction*> to_sync;
std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
size_t n_inserted_im1 = 0; size_t n_inserted_im1 = 0;
bool done = false; bool done = false;
do{ do{
@@ -150,7 +161,7 @@ void membar::run(ir::module &mod) {
for(ir::basic_block* pred: block->get_predecessors()) for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]); pred_read_from.push_back(read_from[pred]);
// apply transfer function // apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war); auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
written_to[block] = result.first; written_to[block] = result.first;
read_from[block] = result.second; read_from[block] = result.second;
} }
@@ -158,8 +169,9 @@ void membar::run(ir::module &mod) {
done = (n_inserted_im1 == n_inserted_i); done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i; n_inserted_im1 = n_inserted_i;
}while(!done); }while(!done);
for(ir::instruction* i: insert_locs) for(auto x: insert_locs){
insert_barrier(i, builder); insert_barrier(x.first, x.second, builder);
}
} }
} }
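Illustrative summary of the new hazard handling (not code from this commit): when transfer() detects a read-after-write or write-after-read hazard, it inspects the writes accumulated in to_sync and picks the synchronization accordingly:

// pending masked_load_async writes -> create_async_wait() (waits for async copies)
// pending copy_to_shared writes    -> create_barrier()    (regular barrier)
// both kinds pending               -> both are emitted at the insertion point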

View File

@@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
//} //}
bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
if(!copy_to_shared)
return false;
ir::value *arg = copy_to_shared->get_operand(0);
ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
if(!ld)
return false;
builder.set_insert_point(copy_to_shared);
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
auto x = dynamic_cast<ir::reduce_inst*>(value); auto x = dynamic_cast<ir::reduce_inst*>(value);
if(!x) if(!x)
@@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) {
continue; continue;
bool was_modified = false; bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder); was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder); // was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder); was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// if(tgt_->as_nvidia()->sm() >= 80)
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified) if(was_modified)
seen.insert(i); seen.insert(i);
} }

View File

@@ -0,0 +1,51 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"
namespace triton {
namespace codegen{
namespace transform{
void reorder::run(ir::module& mod){
ir::builder &builder = mod.get_builder();
std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
ir::value* _ptr = ld->get_pointer_operand();
ir::value* _msk = ld->get_mask_operand();
ir::value* _val = ld->get_false_value_operand();
auto ptr = std::find(block->begin(), block->end(), _ptr);
auto msk = std::find(block->begin(), block->end(), _msk);
auto val = std::find(block->begin(), block->end(), _val);
if(ptr == block->end() || msk == block->end() || val == block->end())
continue;
auto it = std::find(block->begin(), block->end(), i);
int dist_ptr = std::distance(ptr, it);
int dist_msk = std::distance(msk, it);
int dist_val = std::distance(val, it);
if(dist_ptr < dist_msk && dist_ptr < dist_val)
builder.set_insert_point(++ptr);
if(dist_msk < dist_ptr && dist_msk < dist_val)
builder.set_insert_point(++msk);
if(dist_val < dist_ptr && dist_val < dist_msk)
builder.set_insert_point(++val);
ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
to_replace.push_back(std::make_pair(ld, new_ld));
}
}
for(auto& x: to_replace)
x.first->replace_all_uses_with(x.second);
}
}
}
}
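Illustrative effect of the reorder pass on IR ordering (names assumed): when one operand is strictly the latest-defined, the masked load is re-created immediately after it, so the load is issued as early as its dependencies allow.

// before:                        after:
//   %ptr = ...                     %ptr = ...
//   %msk = ...                     %msk = ...
//   %val = ...                     %val = ...
//   ... unrelated work ...         %ld  = masked_load %ptr, %msk, %val
//   %ld  = masked_load ...         ... unrelated work ...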

View File

@@ -48,46 +48,6 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
// CUDA // // CUDA //
/* ------------------------ */ /* ------------------------ */
// architecture
cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
switch(sm.first) {
case 7:
switch(sm.second){
case 0: return Architecture::SM_7_0;
}
case 6:
switch(sm.second){
case 0: return Architecture::SM_6_0;
case 1: return Architecture::SM_6_1;
}
case 5:
switch(sm.second){
case 0: return Architecture::SM_5_0;
case 2: return Architecture::SM_5_2;
default: return Architecture::UNKNOWN;
}
case 3:
switch(sm.second){
case 0: return Architecture::SM_3_0;
case 5: return Architecture::SM_3_5;
case 7: return Architecture::SM_3_7;
default: return Architecture::UNKNOWN;
}
case 2:
switch(sm.second){
case 0: return Architecture::SM_2_0;
case 1: return Architecture::SM_2_1;
default: return Architecture::UNKNOWN;
}
default: return Architecture::UNKNOWN;
}
}
// information query // information query
template<CUdevice_attribute attr> template<CUdevice_attribute attr>
int cu_device::cuGetInfo() const{ int cu_device::cuGetInfo() const{
@@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
return map.at(key); return map.at(key);
} }
// architecture
cu_device::Architecture cu_device::architecture() const{
return nv_arch(compute_capability());
}
// number of address bits // number of address bits
size_t cu_device::address_bits() const{ size_t cu_device::address_bits() const{
return sizeof(size_t)*8; return sizeof(size_t)*8;
@@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
} }
// force the device to be interpreted as a particular cc // force the device to be interpreted as a particular cc
void cu_device::interpret_as(std::pair<size_t, size_t> cc){ void cu_device::interpret_as(int cc){
interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc); interpreted_as_ = std::make_shared<int>(cc);
} }
// compute capability // compute capability
std::pair<size_t, size_t> cu_device::compute_capability() const { int cu_device::compute_capability() const {
if(interpreted_as_) if(interpreted_as_)
return *interpreted_as_; return *interpreted_as_;
size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(); size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(); size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
return std::make_pair(_major, _minor); return major*10 + minor;
} }
// maximum number of threads per block // maximum number of threads per block
@@ -218,7 +173,7 @@ std::string cu_device::infos() const{
// target // target
std::unique_ptr<codegen::target> cu_device::make_target() const { std::unique_ptr<codegen::target> cu_device::make_target() const {
return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target()); return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
} }
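The compute capability is now a single integer, major*10 + minor, e.g. (7,0) -> 70, (7,5) -> 75, (8,0) -> 80. This is the value codegen later queries through the target, so architecture checks become simple comparisons, for example:

if(tgt->as_nvidia()->sm() >= 80){
  // take the Ampere-specific lowering path
}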

View File

@@ -93,6 +93,7 @@ namespace driver
bool dispatch::cuinit(){ bool dispatch::cuinit(){
if(cuda_==nullptr){ if(cuda_==nullptr){
putenv((char*)"CUDA_CACHE_DISABLE=1");
std::string libcuda = tools::getenv("TRITON_LIBCUDA"); std::string libcuda = tools::getenv("TRITON_LIBCUDA");
if(libcuda.empty()) if(libcuda.empty())
cuda_ = dlopen("libcuda.so", RTLD_LAZY); cuda_ = dlopen("libcuda.so", RTLD_LAZY);

View File

@@ -20,7 +20,9 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */
#include <fstream> #include <fstream>
#include <unistd.h>
#include <memory> #include <memory>
#include <regex>
#include "triton/driver/module.h" #include "triton/driver/module.h"
#include "triton/driver/context.h" #include "triton/driver/context.h"
#include "triton/driver/error.h" #include "triton/driver/error.h"
@@ -41,6 +43,19 @@
#include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Cloning.h"
std::string exec(const char* cmd) {
std::array<char, 128> buffer;
std::string result;
std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
if (!pipe) {
throw std::runtime_error("popen() failed!");
}
while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
result += buffer.data();
}
return result;
}
namespace triton namespace triton
{ {
namespace driver namespace driver
@@ -63,11 +78,11 @@ void module::init_llvm() {
} }
module::module(CUmodule mod, bool has_ownership) module::module(CUmodule mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) { : polymorphic_resource(mod, has_ownership), spilled_(0) {
} }
module::module(host_module_t mod, bool has_ownership) module::module(host_module_t mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) { : polymorphic_resource(mod, has_ownership), spilled_(0) {
} }
@@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
file_type_t ft) { file_type_t ft) {
init_llvm(); init_llvm();
// // debug // // debug
// llvm::legacy::PassManager pm; llvm::legacy::PassManager pm;
std::string tmp;
// llvm::raw_string_ostream oss(llir_);
// pm.add(llvm::createPrintModulePass(llvm::outs())); // pm.add(llvm::createPrintModulePass(llvm::outs()));
// pm.add(llvm::createVerifierPass()); pm.add(llvm::createVerifierPass());
// pm.run(*module); pm.run(*module);
// create machine // create machine
module->setTargetTriple(triple); module->setTargetTriple(triple);
std::string error; std::string error;
@@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_
// create execution engine // create execution engine
for(llvm::Function& fn: src->functions()) for(llvm::Function& fn: src->functions())
hst_->functions[fn.getName()] = &fn; hst_->functions[fn.getName().str()] = &fn;
// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); // llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
// auto DL = JTMB.getDefaultDataLayoutForTarget(); // auto DL = JTMB.getDefaultDataLayoutForTarget();
@@ -225,7 +242,8 @@ static std::map<int, int> vptx = {
{10010, 64}, {10010, 64},
{10020, 65}, {10020, 65},
{11000, 70}, {11000, 70},
{11010, 71} {11010, 71},
{11020, 72}
}; };
std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) { std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
@@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
assert(short_ptr); assert(short_ptr);
short_ptr->setValue(true); short_ptr->setValue(true);
// compute capability // compute capability
auto _cc = ((driver::cu_device*)device)->compute_capability(); int cc = ((driver::cu_device*)device)->compute_capability();
int cc = _cc.first*10 + _cc.second;
cc = std::min(cc, max_nvvm_cc);
std::string sm = "sm_" + std::to_string(cc); std::string sm = "sm_" + std::to_string(cc);
// driver version // driver version
int version; int version;
@@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
throw std::runtime_error("Triton requires CUDA 10+"); throw std::runtime_error("Triton requires CUDA 10+");
// PTX version // PTX version
int ptx = vptx.at(version); int ptx = vptx.at(version);
ptx = std::min(ptx, max_nvvm_ptx);
int ptx_major = ptx / 10; int ptx_major = ptx / 10;
int ptx_minor = ptx % 10; int ptx_minor = ptx % 10;
// create // create
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly); module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
std::string result(buffer.begin(), buffer.end()); std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
@@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
} }
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { } cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }
cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){ cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
std::string errbuf(errbufsize, 0);
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
try{ try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); // // compile ptx with ptxas
}catch(exception::cuda::invalid_ptx const &){ // char _fsrc[] = "/tmp/triton_k_XXXXXX";
// char _flog[] = "/tmp/triton_l_XXXXXX";
// int fdsrc = mkstemp(_fsrc);
// int fdlog = mkstemp(_flog);
// std::string fsrc = _fsrc;
// std::string flog = _flog;
// std::ofstream ofs(fsrc);
// ofs << source;
// ofs.close();
// std::string cmd;
// int err;
// driver::cu_device* cu_device = (driver::cu_device*)device;
// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
// err = system(cmd.c_str());
// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
// std::ifstream file(flog);
// std::string log;
// if(file)
// while (!file.eof()) log.push_back(file.get());
// unlink(_fsrc);
// unlink(_flog);
// std::smatch match;
// std::regex expr ("\\b([0-9]+) bytes spill");
// spilled_ = 0;
// while (std::regex_search (log,match,expr)){
// spilled_ += std::stoi(match[1]);
// log = match.suffix();
// }
// std::cout << log << std::endl;
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
CU_JIT_LOG_VERBOSE};
unsigned int errbufsize = 8192;
unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
std::string err(_err);
std::string log(_log);
// std::cout << log << std::endl;
std::smatch match;
std::regex expr ("\\b([0-9]+) bytes spill");
spilled_ = 0;
while (std::regex_search(log,match,expr)){
spilled_ += std::stoi(match[1]);
log = match.suffix();
}
}
catch(exception::cuda::invalid_ptx const &){
//#ifdef TRITON_LOG_PTX_ERROR //#ifdef TRITON_LOG_PTX_ERROR
std::cout << source << std::endl; std::cout << source << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
std::cerr << errbuf << std::endl;
// exit(1); // exit(1);
//#endif //#endif
throw; throw;
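The rewritten loader asks the CUDA driver for verbose JIT logs and scrapes them for spill counts, which make_bin later turns into a hard error. A standalone sketch of the scrape, assuming ptxas-style info lines of the shape the commit's regex expects:

import re

spill_re = re.compile(r'\b([0-9]+) bytes spill')

def count_spills(log):
    # Sum every "N bytes spill" occurrence, as the new cu_module ctor does.
    return sum(int(n) for n in spill_re.findall(log))

sample = 'ptxas info: used 64 registers, 8 bytes spill stores, 4 bytes spill loads'
print(count_spills(sample))  # 12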

View File

@@ -1,5 +1,6 @@
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <iostream>
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/builder.h" #include "triton/ir/builder.h"
#include "triton/ir/constant.h" #include "triton/ir/constant.h"
@@ -253,6 +254,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
value *builder::create_load(value *ptr, const std::string &name){ value *builder::create_load(value *ptr, const std::string &name){
return insert(unmasked_load_inst::create(ptr, name)); return insert(unmasked_load_inst::create(ptr, name));
// type *ty = ptr->get_type()->get_pointer_element_ty();
// value *mask = constant_int::get(get_int1_ty(), 1);
// value *undef = undef_value::get(ty);
// if(ptr->get_type()->is_tile_ty()){
// auto shapes = ptr->get_type()->get_tile_shapes();
// return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
// }
// return insert(masked_load_inst::create(ptr, mask, undef, name));
} }
value *builder::create_store(value *ptr, value *val, const std::string &name){ value *builder::create_store(value *ptr, value *val, const std::string &name){
@@ -263,6 +273,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value,
return insert(masked_load_inst::create(ptr, mask, false_value, name)); return insert(masked_load_inst::create(ptr, mask, false_value, name));
} }
value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){ value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
return insert(masked_store_inst::create(ptr, val, mask, name)); return insert(masked_store_inst::create(ptr, val, mask, name));
} }
@@ -348,13 +359,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
return insert(copy_to_shared_inst::create(arg, name)); return insert(copy_to_shared_inst::create(arg, name));
} }
value *builder::create_copy_from_shared(value *arg, const std::string &name) { value *builder::create_copy_from_shared(value *arg, const std::string &name) {
return insert(copy_from_shared_inst::create(arg, name)); return insert(copy_from_shared_inst::create(arg, name));
} }
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
}
value *builder::create_barrier(const std::string &name) { value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name)); return insert(barrier_inst::create(ctx_, name));
} }
value *builder::create_async_wait() {
return insert(async_wait_inst::create(ctx_));
}
} }
} }
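create_masked_load_async and create_async_wait give the builder the hooks for Ampere's cp.async-style pipelining: issue the next tile's copy early and wait only when its data is consumed. A toy Python model of that schedule (not the IR API itself):

def pipelined(tiles):
    inflight = tiles[0]             # stands in for create_masked_load_async
    acc = 0
    for i in range(len(tiles)):
        current = inflight          # async_wait: this tile's data is now ready
        if i + 1 < len(tiles):
            inflight = tiles[i + 1] # issue the next copy while we compute
        acc += sum(current)         # compute overlaps the copy in flight
    return acc

print(pipelined([[1, 2], [3, 4], [5, 6]]))  # 21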

View File

@@ -463,6 +463,20 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
return new masked_load_inst(ptr, mask, false_value, name, next); return new masked_load_inst(ptr, mask, false_value, name, next);
} }
// masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, name, next);
}
// atomic add // atomic add
atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next) atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
@@ -804,6 +818,14 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
return new barrier_inst(ctx, name, next); return new barrier_inst(ctx, name, next);
} }
async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, name, next);
}
// nv_dynamic_program_idx // nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)

View File

@@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) {
os << get_name(ops[i], cnt++); os << get_name(ops[i], cnt++);
os << (i < num_ops - 1?", ":""); os << (i < num_ops - 1?", ":"");
} }
os << ";" << std::endl; os << ";";
// os << " (";
// for(ir::user* usr: inst->get_users())
// os << get_name(usr, cnt++) << ", " ;
// os << " )";
os << std::endl;
} }
} }
os << "}" << std::endl; os << "}" << std::endl;

View File

@@ -68,9 +68,10 @@ unsigned user::get_num_hidden() const {
value::users_t::iterator user::replace_uses_of_with(value *before, value *after) { value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
for(size_t i = 0; i < ops_.size(); i++) for(size_t i = 0; i < ops_.size(); i++)
if(ops_[i] == before) if(ops_[i] == before){
ops_[i] = after; ops_[i] = after;
after->add_use(this); after->add_use(this);
}
return before->erase_use(this); return before->erase_use(this);
} }
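The value.cc change fixes a dangling-statement bug: without braces, after->add_use(this) executed once after the loop instead of once per rewritten operand slot, so use lists went stale whenever a value appeared more than once. A toy model of the corrected bookkeeping:

def replace_uses_of_with(ops, before, after, use_counts):
    # One add_use per rewritten slot, the behavior the added braces restore.
    for i in range(len(ops)):
        if ops[i] is before:
            ops[i] = after
            use_counts[id(after)] = use_counts.get(id(after), 0) + 1
    return ops

a, b = object(), object()
counts = {}
replace_uses_of_with([a, a, b], a, b, counts)
print(counts[id(b)])  # 2, one use per replaced operand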

View File

@@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
return set_ret(bld_->create_dot(lhs, rhs, _0)); return set_ret(bld_->create_dot(lhs, rhs, _0));
} }
case Token::MASKED_DEREF: { case Token::MASKED_DEREF: {
// TODO: FIXME
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty()); ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
auto it = bld_->get_insert_block();
if(ret_ty->is_tile_ty()) if(ret_ty->is_tile_ty())
false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes()); false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
bld_->set_insert_point(it);
return set_ret(bld_->create_masked_load(rhs, lhs, false_value)); return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
} }
case Token::ELLIPSIS: { case Token::ELLIPSIS: {
@@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) { if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty()) if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
cond,
false_val);
ld->replace_all_uses_with(new_ld); ld->replace_all_uses_with(new_ld);
ld->erase_from_parent(); ld->erase_from_parent();
return set_ret(new_ld); return set_ret(new_ld);
@@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
}); });
if(init_) if(init_)
VisitStmt(init_); VisitStmt(init_);
// VisitExpr(cond_); VisitExpr(cond_);
// ir::value *cond = ret_; ir::value *cond = ret_;
// bld_->create_cond_br(cond, loop_bb, next_bb); bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->create_br(loop_bb); // bld_->create_br(loop_bb);
bld_->set_insert_point(loop_bb); bld_->set_insert_point(loop_bb);
if(body_) if(body_)
VisitStmt(body_); VisitStmt(body_);
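The ForStmt lowering now evaluates the condition and emits a conditional branch before the first iteration, instead of branching unconditionally into the body. The difference only shows up for zero-trip loops:

def while_style(n):     # new lowering: test, then body
    i = trips = 0
    while i < n:
        trips += 1
        i += 1
    return trips

def do_while_style(n):  # old lowering: body runs at least once
    i = trips = 0
    while True:
        trips += 1
        i += 1
        if not (i < n):
            break
    return trips

print(while_style(0), do_while_style(0))  # 0 1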

View File

@@ -1,4 +1,4 @@
#include <string> #include <string>
#include <mutex> #include <mutex>
#include <regex> #include <regex>
#include <functional> #include <functional>
@@ -9,11 +9,13 @@
#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h" #include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h" #include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h" #include "triton/codegen/selection/generator.h"
@@ -29,6 +31,7 @@
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/print.h" #include "triton/ir/print.h"
#include "triton/runtime/error.h"
#include "triton/tools/bench.hpp" #include "triton/tools/bench.hpp"
#include "triton/tools/sha1.hpp" #include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp" #include "triton/tools/sys/getenv.hpp"
@@ -67,7 +70,7 @@ void _loop_nest(std::vector<size_t> const & ranges,
/* OPTIONS */ /* OPTIONS */
/* --------------------- */ /* --------------------- */
std::string function::options_t::to_str() const{ std::string options_t::to_str() const{
std::string ret = "nw-" + std::to_string(num_warps); std::string ret = "nw-" + std::to_string(num_warps);
for(const auto& x : defines){ for(const auto& x : defines){
ret += '-'; ret += '-';
@@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) {
throw std::runtime_error("unknown type"); throw std::runtime_error("unknown type");
} }
void function::caller::write(std::ofstream &ofs) { //void function::caller::write(std::ofstream &ofs) {
// write name // // write name
ofs << name_ << std::endl; // ofs << name_ << std::endl;
// write signature // // write signature
for(size_t i = 0; i < param_tys_.size(); i++) // for(size_t i = 0; i < param_tys_.size(); i++)
ofs << param_tys_[i] << " "; // ofs << param_tys_[i] << " ";
ofs << std::endl; // ofs << std::endl;
// write module // // write module
std::string source = ((driver::cu_module*)(&*parent_))->source(); // std::string source = ((driver::cu_module*)(&*parent_))->ptx();
ofs << source; // ofs << source;
} //}
void function::caller::read(std::ifstream &ifs) { //void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
// read name // // read name
std::getline(ifs, name_); // std::getline(ifs, name_);
// read signature // // read signature
std::string line; // std::string line;
std::getline(ifs, line); // std::getline(ifs, line);
std::istringstream current(line); // std::istringstream current(line);
int param; // int param;
param_tys_.clear(); // param_tys_.clear();
while(current >> param) // while(current >> param)
param_tys_.push_back((arg_type)param); // param_tys_.push_back((arg_type)param);
// read module // // read module
std::string src((std::istreambuf_iterator<char>(ifs)), // std::string src((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>()); // std::istreambuf_iterator<char>());
parent_.reset(new driver::cu_module(src)); // parent_.reset(new driver::cu_module(ctx, src));
bin_.reset(driver::kernel::create(&*parent_, name_.c_str())); // bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
} //}
function::caller::caller(std::ifstream &ifs, const options_t& opt) //function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
: opt_(opt) { // : opt_(opt) {
read(ifs); // read(ctx, ifs);
} //}
function::caller::caller(ir::function *ir, function::caller::caller(ir::function *ir,
std::shared_ptr<driver::module> parent, const options_t& opt) std::shared_ptr<driver::module> parent, const options_t& opt)
@@ -198,20 +201,23 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
// generate llvm code // generate llvm code
llvm::LLVMContext ctx; llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// optimizations
bool cts_use_async = target->as_nvidia()->sm() >= 80;
// create passes // create passes
codegen::analysis::align align; codegen::analysis::align align;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::disassociate disassociate; codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts); codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness); codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation); codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce; codegen::transform::dce dce;
codegen::transform::peephole peephole; codegen::transform::peephole peephole(target.get());
codegen::transform::reassociate reassociate; codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::cts cts; codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes // run passes
dce.run(module); dce.run(module);
disassociate.run(module); disassociate.run(module);
@@ -233,17 +239,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
} }
peephole.run(module); peephole.run(module);
dce.run(module); dce.run(module);
// ir::print(module, std::cout);
align.run(module); align.run(module);
axes.run(module); axes.run(module);
layouts.run(module); layouts.run(module);
swizzle.run(module);
liveness.run(module); liveness.run(module);
allocation.run(module); allocation.run(module);
if(allocation.allocated_size() > device->max_shared_memory()) if(allocation.allocated_size() > device->max_shared_memory())
throw std::runtime_error("using too much shared memory"); throw exception::out_of_shared_memory();
barriers.run(module); barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm); isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm))); std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
if(res->spilled() > 256)
throw exception::out_of_registers();
return res; return res;
} }
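make_bin now fails fast with typed exceptions: out_of_shared_memory when the allocation pass exceeds the device budget, and out_of_registers when the JIT log reports more than 256 spill bytes. The guards, sketched with stand-in exception classes:

class OutOfSharedMemory(Exception): pass
class OutOfRegisters(Exception): pass

def check_resources(allocated_size, max_shared_memory, spilled_bytes):
    # Mirrors the two new checks; 256 is the commit's spill budget.
    if allocated_size > max_shared_memory:
        raise OutOfSharedMemory()
    if spilled_bytes > 256:
        raise OutOfRegisters()

check_resources(48 * 1024, 96 * 1024, 0)  # passes on a hypothetical device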
@@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) {
auto ir = make_ir(parser); auto ir = make_ir(parser);
// triton-ir -> binary // triton-ir -> binary
std::unique_ptr<driver::module> bin; std::unique_ptr<driver::module> bin;
// try{ try{
bin = make_bin(*ir, device, opt); bin = make_bin(*ir, device, opt);
// }catch(const std::runtime_error&){ }catch(const exception::base&){
// return nullptr; throw;
// } }
// create callable // create callable
ir::function *tmp = ir->get_function_list()[0]; ir::function *tmp = ir->get_function_list()[0];
callers_[opt].reset(new caller(tmp, std::move(bin), opt)); callers_[opt].reset(new caller(tmp, std::move(bin), opt));
@@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space)
for(const auto& x: space.defines) for(const auto& x: space.defines)
ranges.push_back(x.second.size()); ranges.push_back(x.second.size());
// functor for source with given option // functor for source with given option
std::map<options_t, std::string> err;
auto do_make = [&](std::vector<size_t> params) { auto do_make = [&](std::vector<size_t> params) {
// compilation options // compilation options
unsigned i = 0; unsigned i = 0;
@@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space)
for(auto D: space.defines) for(auto D: space.defines)
opt.defines[D.first] = D.second[params[i++]]; opt.defines[D.first] = D.second[params[i++]];
// compile // compile
make(device, opt); try{
make(device, opt);
}catch(const exception::base& e){
err[opt] = e.what();
}
}; };
// multi-threaded compilation // multi-threaded compilation
_loop_nest(ranges, do_make); _loop_nest(ranges, do_make);
if(callers_.empty()) if(callers_.empty()){
throw std::runtime_error("could not compile kernel"); std::ostringstream dbg;
dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
for(auto x: err){
dbg << "[ ";
dbg << x.first.num_warps << ", ";
dbg << "{ ";
for(const auto& y: x.first.defines)
dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
dbg << " } ] -> " << x.second << std::endl;
}
throw exception::no_valid_configuration(dbg.str());
}
} }
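precompile now records why each candidate configuration failed and, only when none compiles, raises no_valid_configuration carrying the full per-config report instead of a bare "could not compile kernel". The control flow, sketched:

def precompile(configs, make):
    errors, callers = {}, {}
    for opt in configs:
        try:
            callers[opt] = make(opt)     # may raise, like function::make
        except Exception as e:
            errors[opt] = str(e)         # remember the failure, keep going
    if not callers:
        report = '\n'.join('%s -> %s' % kv for kv in errors.items())
        raise RuntimeError('Auto-Tuner could not find any valid configuration:\n' + report)
    return callers

print(precompile(['nw-4'], lambda opt: 'binary for ' + opt))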
std::string function::ptx(driver::device* device, const options_t& opt) { std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
make(device, opt); make(device, opt);
const auto& fn = callers_.at(opt); const auto& fn = callers_.at(opt);
if(!fn) if(!fn)
return ""; return "";
return ((driver::cu_module*)fn->parent())->source(); switch(mode){
case ASM_LLIR:{
return fn->parent()->llir();
}
case ASM_NV_PTX:
case ASM_NV_SASS:{
std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
// SASS
std::string input = std::tmpnam(nullptr);
std::string output = std::tmpnam(nullptr);
std::ofstream ofs(input);
ofs << ptx;
ofs.close();
if(mode == ASM_NV_PTX)
return ptx;
std::string cmd;
int err;
// compile ptx
driver::cu_device* cu_device = (driver::cu_device*)device;
cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
err = system(cmd.c_str());
// disassemble
cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
err = system(cmd.c_str());
std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
std::string to_delete = " /*";
std::ifstream ifs(output);
std::string line;
std::string sass;
while(std::getline(ifs, line))
if(!std::regex_match(line, comment))
sass += line + "\n";
return sass;
}
default:
return "";
}
} }
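The new ASM_NV_SASS mode round-trips through the CUDA toolchain: dump the PTX to a temp file, assemble it with ptxas, disassemble with cuobjdump, and filter out the hex-encoded comment columns. The same pipeline in Python, assuming the toolkit binaries are on PATH:

import re, subprocess, tempfile

def ptx_to_sass(ptx, sm):
    with tempfile.NamedTemporaryFile('w', suffix='.ptx', delete=False) as f:
        f.write(ptx)
        src = f.name
    subprocess.run(['ptxas', '--gpu-name=sm_%d' % sm, src, '-o', src + '.o'], check=True)
    dump = subprocess.run(['cuobjdump', '--dump-sass', src + '.o'],
                          capture_output=True, text=True, check=True).stdout
    comment = re.compile(r' *\/\* 0x[0-9a-f]+ \*\/')  # the commit's filter
    return '\n'.join(l for l in dump.splitlines() if not comment.fullmatch(l))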
// returns program with best compilation options for given parameter // returns program with best compilation options for given parameter

View File

@@ -1,56 +0,0 @@
import triton
import numpy as np
from enum import Enum
class MODE(Enum):
TF = 1
TORCH = 2
try:
import tensorflow as tf
mode = MODE.TF
except ModuleNotFoundError:
pass
try:
import torch
mode = MODE.TORCH
except ModuleNotFoundError:
pass
C, H, W, B = 32, 1, 1, 128
x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
gamma = np.random.uniform(-1, 1, C).astype(np.float32)
beta = np.random.uniform(-1, 1, C).astype(np.float32)
dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
if mode == MODE.TORCH:
fw_x = torch.from_numpy(x).cuda()
fw_gamma = torch.from_numpy(gamma).cuda()
fw_beta = torch.from_numpy(beta).cuda()
fw_dy = torch.from_numpy(dy).cuda()
# register gradients
fw_x.requires_grad_(True)
fw_gamma.requires_grad_(True)
fw_beta.requires_grad_(True)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
fw_y.backward(fw_dy)
if mode == MODE.TF:
fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
# run
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
sess.run(tf.global_variables_initializer())
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
print(result)

View File

@@ -1,213 +0,0 @@
import triton
import torch
from torch.utils.cpp_extension import load
import numpy as np
#import utils
from time import time
torch.manual_seed(0)
#torch.backends.cudnn.benchmark = True
configs = []
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
#(8192, 8192, 8192),
(64, 64, 64000),
(64, 64, 128000),
(256, 256, 64000),
(256, 256, 128000),
(1536, 16, 1536),
(1536, 32, 1536),
(1536, 64, 1536),
# (1536, 128, 1536),
# (4096, 16, 4096),
# (4096, 32, 4096),
# (4096, 64, 4096),
# (4096, 128, 4096),
# (127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a.t(), b)
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b.t())
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
# Relative attention
NTHSE = [
(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
# (64, 1024, 1, 64, 64),
(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
]
#for N, T, H, S, E in NTHSE:
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
# 1D Dense convolution
NCHKR = [
#(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
configs += [([N, C, H],
[C, R, K],
[N, K, H - R + 1],
torch_fn,
'nc(h+r),crk->nkh',
dict(), None, None, None)]
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
#(128, 3, 32, 32, 64, 3, 3),
#(1, 1024, 32, 112, 112, 1024, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, G, H, W, K, R, S in NCHWKRS:
stride = 2
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
P = (H - R + 1) // stride
Q = (W - S + 1) // stride
transform_a = lambda a: a.view(N, G, C // G, H, W)
transform_b = lambda b: b.view(C // G, R, S, G, K // G)
transform_c = lambda c: c.view(N, K, P, Q)
configs += [([N, C, H, W],
[C // G, R, S, K],
[N, G, K // G, P, Q],
torch_fn,
'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
dict(), transform_a, transform_b, transform_c)]
# 3D Dense Convolution
NCDHWKTRS = [
#(8, 32, 27, 100, 100, 64, 3, 3, 3),
#(8, 64, 23, 48, 48, 256, 3, 3, 3),
#(8, 256, 19, 22, 22, 640, 3, 3, 3),
#(8, 640, 15, 36, 36, 384, 3, 3, 3)
]
for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
configs += [([N, C, D, H, W],
[C, T, R, S, K],
[N, K, D - T + 1, H - R + 1, W - R + 1],
torch_fn,
'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
dict(), None, None, None)]
# Shift convolution
shift_cuda = torch.utils.cpp_extension.load(
'shift_cuda', ['kernels/shift_cuda.cpp',
'kernels/shift_cuda_kernel.cu'],
extra_cflags=['-O3'])
class shift(torch.autograd.Function):
@staticmethod
def forward(ctx, x, shift):
ctx.save_for_backward(shift)
return shift_cuda.forward(x, shift)
@staticmethod
def backward(ctx, grad_output):
shift, = ctx.saved_tensors
grad_output = shift_cuda.backward(grad_output, shift)
return grad_output, None
NCHWKRS = [
#(8, 64, 128, 128, 128, 3, 3),
#(8, 128, 64, 64, 256, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
def shift_conv(a, b, **kwargs):
shift_h, shift_w = kwargs['sh'], kwargs['sw']
shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
shift_torch = torch.from_numpy(shift_torch).cuda()
a = shift.apply(a, shift_torch)
b = b.permute(1, 0)
b = b.reshape(b.shape[0], b.shape[1], 1, 1)
return torch.nn.functional.conv2d(a, b)
configs += [([N, C, H, W],
[C, K],
[N, K, H, W],
shift_conv,
'nc(h + sh[c])(w + sw[c]),ck->nkhw',
{'sh': shift_h, 'sw': shift_w},
None, None, None)]
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
transform_a, transform_b, transform_c in configs:
dtype = torch.cuda.FloatTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# reference output
if torch_fn:
rc = torch_fn(a, b, **arrays)
else:
rc = torch.einsum(expr, a, b)
# triton output
ta = a if transform_a is None else transform_a(a)
tb = b if transform_b is None else transform_b(b)
tc = torch.empty(c_shape, device=a.device)
triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
ctx = triton.ops._einsum.registry[tc]
tc = tc if transform_c is None else transform_c(tc)
# performance relative to equivalent matrix multiplication
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
cmp_eqbmm = True
if cmp_eqbmm:
a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda()
c = torch.empty((B, M, N), device=a.device).cuda()
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
cmp_str = f'({ratio:4.2f})'
else:
cmp_str = ''
# test and benchmark
bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')

View File

@@ -1,42 +0,0 @@
#include <torch/torch.h>
#include <vector>
// CUDA forward declarations
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift);
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift);
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor shift_forward(
const at::Tensor input,
const at::Tensor shift) {
CHECK_INPUT(input);
CHECK_INPUT(shift);
return shift_cuda_forward(input, shift);
}
at::Tensor shift_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
CHECK_INPUT(grad_input);
CHECK_INPUT(shift);
return shift_cuda_backward(grad_input, shift);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &shift_forward, "Shift forward (CUDA)");
m.def("backward", &shift_backward, "Shift backward (CUDA)");
}

View File

@@ -1,111 +0,0 @@
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void shift_cuda_forward_kernel(
const scalar_t* __restrict__ input,
const int32_t* __restrict__ shift,
scalar_t* __restrict__ output,
const int32_t B,
const int32_t C,
const int32_t H,
const int32_t W) {
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t size = B*C*H*W;
const int32_t CHW = C*H*W;
const int32_t HW = H*W;
const int32_t b = idx / CHW;
const int32_t c = (idx - b*CHW) / HW;
const int32_t h = (idx - b*CHW - c*HW) / W;
const int32_t w = idx - b*CHW - c*HW - h*W;
const int32_t target_w = w + shift[2*c];
const int32_t target_h = h + shift[2*c + 1];
const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
output[target_idx] = input[idx];
}
}
template <typename scalar_t>
__global__ void shift_cuda_backward_kernel(
const scalar_t* __restrict__ grad_input,
scalar_t* __restrict__ grad_output,
const int32_t* __restrict__ shift,
const int32_t B,
const int32_t C,
const int32_t W,
const int32_t H) {
const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t size = B*C*W*H;
const int32_t CWH = C*W*H;
const int32_t WH = W*H;
const int32_t b = idx / CWH;
const int32_t c = (idx - b*CWH) / WH;
const int32_t w = (idx - b*CWH - c*WH) / W;
const int32_t h = idx - b*CWH - c*WH - w*H;
const int32_t target_w = w - shift[2*c];
const int32_t target_h = h - shift[2*c + 1];
const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
grad_output[target_idx] = grad_input[idx];
}
}
} // namespace
at::Tensor shift_cuda_forward(
const at::Tensor input,
const at::Tensor shift) {
const auto B = input.size(0);
const auto C = input.size(1);
const auto H = input.size(2);
const auto W = input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto output = at::zeros_like(input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
input.data<scalar_t>(),
shift.data<int32_t>(),
output.data<scalar_t>(),
B,
C,
H,
W);
}));
return output;
}
at::Tensor shift_cuda_backward(
const at::Tensor grad_input,
const at::Tensor shift) {
const auto B = grad_input.size(0);
const auto C = grad_input.size(1);
const auto H = grad_input.size(2);
const auto W = grad_input.size(3);
const auto size = B*C*W*H;
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
auto grad_output = at::zeros_like(grad_input);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
grad_input.data<scalar_t>(),
grad_output.data<scalar_t>(),
shift.data<int32_t>(),
B,
C,
H,
W);
}));
return grad_output;
}

View File

@@ -1,109 +0,0 @@
import triton
import numpy
import torch
import itertools
torch.manual_seed(0)
numpy.random.seed(0)
def to_sparse(expr, data, layout, shape, block):
# shape of result
sparse = None
shape_ret = []
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
shape_ret.append(int(layout.sum()))
if d.isupper():
shape_ret.append(block[d])
else:
shape_ret.append(shape[i])
# iterator
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
# create result
ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
ret_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def to_dense(expr, data, layout, shape, block):
sparse = None
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
data_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def test_expr(expr, shape, blocks):
# decompose expr
expr_a, expr_bc = expr.split(",")
expr_b, expr_c = expr_bc.split("->")
# check with argument is sparse
sparse_a = any(x.isupper() for x in expr_a)
sparse_b = any(x.isupper() for x in expr_b)
sparse_c = any(x.isupper() for x in expr_c)
# allocate data
shape_a = [shape[d.lower()] for d in expr_a]
shape_b = [shape[d.lower()] for d in expr_b]
shape_c = [shape[d.lower()] for d in expr_c]
ref_a = torch.rand(*shape_a, device='cuda')
ref_b = torch.rand(*shape_b, device='cuda')
ref_c = torch.zeros(*shape_c, device='cuda')
# layouts
layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
layout_a = torch.randint(0, 2, layout_a, device='cuda')
layout_b = torch.randint(0, 2, layout_b, device='cuda')
layout_c = torch.randint(0, 2, layout_c, device='cuda')
# triton computation
triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
torch.cuda.synchronize()
# reference computation
ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
if sparse_c:
ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
torch.cuda.synchronize()
print((ref_c - triton_c).abs().max())
# shape characteristics
test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})

View File

@@ -171,7 +171,7 @@ class _conv(torch.autograd.Function):
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines)) _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype] delta, kernel = _conv.kernel[dtype]
# allocate output # allocate output
c = triton.empty([Z, CO, P, Q], dtype=dtype) c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue # enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')), grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))] triton.cdiv(CO, opt.d('TN'))]

View File

@@ -3,6 +3,9 @@ import triton
class _dot(torch.autograd.Function): class _dot(torch.autograd.Function):
src = """ src = """
#define STM 4
#define STN 4
__global__ void dot(TYPE * A __noalias __readonly __aligned(16), __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16), TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16), TYPE * C __noalias __aligned(16),
@@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int ldb __multipleof(8), int ldb __multipleof(8),
int ldc __multipleof(8)) { int ldc __multipleof(8)) {
// prologue // prologue
int ridx = get_program_id(0); int pid = get_program_id(0);
int ridy = get_program_id(1); int pidz = get_program_id(2);
int ridz = get_program_id(2); int gridm = M / TM;
int gridx = M / TM; int gridn = N / TN;
int gridy = N / TN; int stgridm = (gridm + STM - 1) / STM;
int rid = ridx + ridy * gridx; int stgridn = (gridn + STN - 1) / STN;
ridx = rid / gridy; int stid = pid / (STM * STN);
ridy = rid % gridy; int laneid = pid % (STM * STN);
int rm[TM] = ridx * TM + 0 ... TM; int stm = stid / stgridn;
int rn[TN] = ridy * TN + 0 ... TN; int stn = stid % stgridn;
int lanem = laneid / STN;
int lanen = laneid % STN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting // reduction splitting
K = K / TZ; K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK; int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands // pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
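The rewritten prologue flattens the launch grid to one dimension and decodes the program id into STM x STN super-tiles, so consecutive program ids compute adjacent C tiles and reuse more L2-resident A and B data. The index arithmetic, transcribed:

def swizzle(pid, gridn, STM=4, STN=4):
    # Decode a linear program id into (pidm, pidn) exactly as the kernel does.
    stgridn = (gridn + STN - 1) // STN
    stid, laneid = divmod(pid, STM * STN)
    stm, stn = divmod(stid, stgridn)
    lanem, lanen = divmod(laneid, STN)
    return stm * STM + lanem, stn * STN + lanen

for pid in range(8):
    print(pid, swizzle(pid, gridn=8))  # ids 0..15 all land in one 4x4 super-tile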
@@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
// reduction loop // reduction loop
float acc[TM, TN] = 0; float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){ for(int k = K; k > 0; k -= TK){
acc += a @ b;
bool checka[TM, TK] = k > TK; bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK; bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK; pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK; pb += TK * STRIDE_BK;
acc += a @ b;
a = *?(checka)pa; a = *?(checka)pa;
b = *?(checkb)pb; b = *?(checkb)pb;
} }
@@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE c[TM, TN] = acc; TYPE c[TM, TN] = acc;
// epilogue // epilogue
int rxm[TM] = ridx * TM + 0 ... TM; int rxm[TM] = pidm * TM + 0 ... TM;
int rxn[TN] = ridy * TN + 0 ... TN; int rxn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :]; int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc; TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N); bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
*?(checkc) pc = c; *?(checkc) pc = c;
#else #else
// accumulate partial result using spin-locks // accumulate partial result using spin-locks
int *plock = locks + rid; int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1); int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount; int count = *pcount;
@@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb', 'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [128], 'TM' : [128],
'TN' : [128], 'TN' : [128],
'TK' : [16], 'TK' : [32],
'TZ' : [1] 'TZ' : [1]
} }
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines) _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
M, K = a.shape M, K = a.shape
K, N = b.shape K, N = b.shape
c = torch.empty([M,N], dtype=dtype, device=a.device) c = torch.empty([M,N], dtype=dtype, device=a.device)
print(kernel.asm('sass', c.device))
print(kernel.asm('ptx', c.device))
# enqueue # enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K, time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0), grid=grid) a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c return c
@@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half()
zc = torch.matmul(a,b) zc = torch.matmul(a,b)
zc_ = dot(a,b) zc_ = dot(a,b)
print(torch.allclose(zc, zc_)) print(torch.allclose(zc, zc_))

View File

@@ -111,7 +111,7 @@ setup(
author_email='ptillet@g.harvard.edu', author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations', description='A language and compiler for custom Deep Learning operations',
long_description='', long_description='',
packages=['triton', 'triton/_C', 'triton/ops'], packages=['triton', 'triton/_C'],
install_requires=['numpy', 'torch', 'sympy'], install_requires=['numpy', 'torch', 'sympy'],
package_data={'': data}, package_data={'': data},
ext_modules=[CMakeExtension('triton', 'triton/_C/')], ext_modules=[CMakeExtension('triton', 'triton/_C/')],

View File

@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {
void register_fn(const map_key_t& key, void register_fn(const map_key_t& key,
const std::string& src, const std::string& src,
const rt::function::options_space_t& opt) { const rt::options_space_t& opt) {
if(id_fn_map.find(key) == id_fn_map.end()) if(id_fn_map.find(key) == id_fn_map.end())
id_fn_map[key].reset(new rt::function(src, opt, "")); id_fn_map[key].reset(new rt::function(src, opt, ""));
} }
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
id_fn_map.erase(key); id_fn_map.erase(key);
} }
std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) { std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
triton::driver::cu_device device(torch_get_cuda_device(key.second), false); triton::driver::cu_device device(key.second, false);
return id_fn_map[key]->ptx(&device, opt); return id_fn_map[key]->get_asm(mode, &device, opt);
} }
void cleanup() { void cleanup() {
@@ -63,7 +63,7 @@ size_t make_op_id() {
/* Function signature */ /* Function signature */
void make_module(const std::string& src, ir::module* ir, void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) { const runtime::options_space_t& opt) {
std::string copy = triton::runtime::function::preheader() + src; std::string copy = triton::runtime::function::preheader() + src;
// pre-process // pre-process
TokenSequence tokens; TokenSequence tokens;
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
} }
std::vector<rt::arg_type> get_fn_signature(const std::string& src, std::vector<rt::arg_type> get_fn_signature(const std::string& src,
const runtime::function::options_space_t& opt) { const runtime::options_space_t& opt) {
// triton-ir code-gen // triton-ir code-gen
ir::context ctx; ir::context ctx;
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx)); auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
return ret; return ret;
} }
typedef triton::runtime::function::options_t options_t; typedef triton::runtime::options_t options_t;
typedef triton::runtime::function::options_space_t options_space_t; typedef triton::runtime::options_space_t options_space_t;
PYBIND11_MODULE(libtriton, m) { PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API"; m.doc() = "Python bindings to the C++ Triton API";
@@ -113,6 +113,10 @@ PYBIND11_MODULE(libtriton, m) {
.value("double", rt::DOUBLE_T) .value("double", rt::DOUBLE_T)
.value("buffer", rt::BUFFER_T); .value("buffer", rt::BUFFER_T);
pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
.value("ptx", rt::ASM_NV_PTX)
.value("sass", rt::ASM_NV_SASS);
pybind11::class_<options_t>(m, "options") pybind11::class_<options_t>(m, "options")
.def(pybind11::init<>()) .def(pybind11::init<>())
.def("d", &options_t::D<int>) .def("d", &options_t::D<int>)
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11 // hooks into triton constructs since frameworks may not use pybind11
m.def("get_fn_signature", &get_fn_signature); m.def("get_fn_signature", &get_fn_signature);
m.def("get_fn_ptx", &get_fn_ptx); m.def("get_fn_asm", &get_fn_asm);
m.def("register_grid", &register_grid); m.def("register_grid", &register_grid);
m.def("delete_grid", &delete_grid); m.def("delete_grid", &delete_grid);
m.def("register_fn", &register_fn); m.def("register_fn", &register_fn);

View File

@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
} }
} }
torch::Tensor raw_like(torch::Tensor x){ torch::Tensor cuda_empty_like(torch::Tensor x){
if(x.nbytes() == 0) if(x.nbytes() == 0)
return torch::empty_like(x); return torch::empty_like(x);
C10_CUDA_CHECK(cudaSetDevice(x.device().index())); void* data;
auto shape = x.sizes(); cudaMalloc(&data, x.nbytes());
CUdeviceptr data; auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
ret.copy_(x);
return ret; return ret;
} }
@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
static auto registry = torch::RegisterOperators() static auto registry = torch::RegisterOperators()
.op("triton::launch_kernel", &launch_kernel) .op("triton::launch_kernel", &launch_kernel)
.op("triton::raw_like", &raw_like) .op("triton::cuda_empty_like", &cuda_empty_like)
.op("triton::cdiv_sum", &cdiv_sum) .op("triton::cdiv_sum", &cdiv_sum)
.op("triton::synchronize", &synchronize); .op("triton::synchronize", &synchronize);

View File

@@ -1,7 +1,4 @@
from .kernel import * from .kernel import *
import triton.ops
#import triton.nn
# clean-up libtriton resources # clean-up libtriton resources
import atexit import atexit

View File

@@ -68,8 +68,17 @@ class kernel:
size = sum([sizes[x] for x in arg_types]) size = sum([sizes[x] for x in arg_types])
self.tys = ''.join([codes[x] for x in arg_types]) self.tys = ''.join([codes[x] for x in arg_types])
def ptx(self, device, **kwargs): def asm(self, mode, device, **kwargs):
dev_id = device.index dev_id = device.index
# assembly mode
supported = {
'ptx': libtriton.asm_mode.ptx,
'sass': libtriton.asm_mode.sass,
}
if mode not in supported:
raise ValueError('ASM mode must be one of ' + ', '.join(supported))
mode = supported[mode]
# disambiguates #defines
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt) libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
def _single_value_or_err(x, key): def _single_value_or_err(x, key):
if isinstance(x, list) and len(x) == 1: if isinstance(x, list) and len(x) == 1:
@@ -86,15 +95,18 @@ class kernel:
opt = libtriton.options() opt = libtriton.options()
opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps') opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
opt.defines = defines opt.defines = defines
return libtriton.get_fn_ptx((self.op_id, dev_id), opt) # run
return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)
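kernel.asm replaces kernel.ptx and dispatches on a small mode table; a sketch of the validation, with the libtriton enum values stubbed out as strings:

def asm_mode(mode):
    supported = {'ptx': 'ASM_NV_PTX', 'sass': 'ASM_NV_SASS'}  # stub enum values
    if mode not in supported:
        raise ValueError('ASM mode must be one of ' + ', '.join(supported))
    return supported[mode]

print(asm_mode('ptx'))

Callers then read back assembly with kernel.asm('ptx', device) or kernel.asm('sass', device), as the dot example above does.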
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if 'TRITON_DEBUG_MODE' in os.environ: if 'TRITON_DEBUG_MODE' in os.environ:
_args = args _args = args
args = [x for x in args] args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
for i in range(len(args)): for i in range(len(args)):
if isinstance(args[i], torch.Tensor): if isinstance(args[i], torch.Tensor):
args[i] = torch.ops.triton.raw_like(args[i]) args[i] = torch.ops.triton.cuda_empty_like(args[i])
args[i].copy_(_args[i])
torch.cuda.synchronize()
for x in args: for x in args:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
device = x.device.index device = x.device.index
@@ -116,6 +128,8 @@ class kernel:
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else [] constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants) torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
if 'TRITON_DEBUG_MODE' in os.environ: if 'TRITON_DEBUG_MODE' in os.environ:
torch.cuda.synchronize()
for i in range(len(args)): for i in range(len(args)):
if isinstance(args[i], torch.Tensor): if isinstance(args[i], torch.Tensor):
_args[i].copy_(args[i]) _args[i].copy_(args[i].clone())
args = _args
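TRITON_DEBUG_MODE now launches the kernel on freshly allocated CUDA copies of every tensor argument (cuda_empty_like plus an explicit copy) and synchronizes on both sides of the launch, so out-of-bounds writes land in scratch memory and asynchronous errors surface at the launch site; results are copied back afterwards. A toy model of the flow:

import torch

def debug_launch(launch, *tensors):
    scratch = [t.clone() for t in tensors]  # run on copies, not user data
    torch.cuda.synchronize()                # flush pending work first
    launch(*scratch)
    torch.cuda.synchronize()                # surface async launch errors here
    for dst, src in zip(tensors, scratch):
        dst.copy_(src)                      # commit results back to the caller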

View File

@@ -1,2 +0,0 @@
from .conv import replace_conv2d
from .attention import replace_mah

View File

@@ -1,312 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
def bmm(x, w, mask = None):
b, m, k = x.size()
b, k, n = w.size()
out = torch.empty([b, m, n], device=x.device)
triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False)
return out
def multi_head_attention_forward(query, # type: Tensor
key, # type: Tensor
value, # type: Tensor
embed_dim_to_check, # type: int
num_heads, # type: int
in_proj_weight, # type: Tensor
in_proj_bias, # type: Tensor
bias_k, # type: Optional[Tensor]
bias_v, # type: Optional[Tensor]
add_zero_attn, # type: bool
dropout_p, # type: float
out_proj_weight, # type: Tensor
out_proj_bias, # type: Tensor
training=True, # type: bool
key_padding_mask=None, # type: Optional[Tensor]
need_weights=True, # type: bool
attn_mask=None, # type: Optional[Tensor]
use_separate_proj_weight=False, # type: bool
q_proj_weight=None, # type: Optional[Tensor]
k_proj_weight=None, # type: Optional[Tensor]
v_proj_weight=None, # type: Optional[Tensor]
static_k=None, # type: Optional[Tensor]
static_v=None, # type: Optional[Tensor]
acc_bitmask=None
):
# type: (...) -> Tuple[Tensor, Optional[Tensor]]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: mask that prevents attention to certain positions. This is an additive mask
(i.e. the values will be added to the attention layer).
use_separate_proj_weight: the function accepts the projection weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
- attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size() == value.size()
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if not use_separate_proj_weight:
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
if key is None:
assert value is None
k = None
v = None
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = F.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = F.linear(value, _w, _b)
else:
q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
len1, len2 = q_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == query.size(-1)
k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
len1, len2 = k_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == key.size(-1)
v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
len1, len2 = v_proj_weight_non_opt.size()
assert len1 == embed_dim and len2 == value.size(-1)
if in_proj_bias is not None:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
else:
q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
q = q * scaling
if bias_k is not None and bias_v is not None:
if static_k is None and static_v is None:
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat([attn_mask,
torch.zeros((attn_mask.size(0), 1),
dtype=attn_mask.dtype,
device=attn_mask.device)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
dtype=key_padding_mask.dtype,
device=key_padding_mask.device)], dim=1)
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert bias_k is None
assert bias_v is None
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
dtype=attn_mask.dtype,
device=attn_mask.device)], dim=1)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
dtype=key_padding_mask.dtype,
device=key_padding_mask.device)], dim=1)
attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask)
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = F.softmax(
attn_output_weights, dim=-1)
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = bmm(attn_output_weights, v, mask=acc_bitmask)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
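# For reference, the core computation above is, per attention head,
#   attn_output = softmax(q @ k.T * head_dim**-0.5 + attn_mask) @ v
# with both matrix products routed through the Triton-backed `bmm` so that an
# `acc_bitmask` can be applied to the accumulators.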
class MultiheadAttention(nn.modules.activation.MultiheadAttention):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None):
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim)
self.acc_bitmask = acc_bitmask
def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):
if not self._qkv_same_embed_dim:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
acc_bitmask=self.acc_bitmask)
else:
return multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
acc_bitmask=self.acc_bitmask)
def replace_mah(model, mask = None):
for child_name, child in model.named_children():
if isinstance(child, nn.modules.activation.MultiheadAttention):
add_bias_kv = child.bias_k is not None
device = child.in_proj_weight.device
mah = MultiheadAttention(child.embed_dim, child.num_heads,
dropout=child.dropout, add_bias_kv=add_bias_kv,
add_zero_attn=child.add_zero_attn, kdim=child.kdim,
vdim=child.vdim, acc_bitmask=mask).to(device)
for yparam, xparam in zip(mah.parameters(), child.parameters()):
yparam.data.copy_(xparam.data)
setattr(model, child_name, mah)
else:
replace_mah(child, mask)
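# Quick check (illustrative sketch, not part of the original file): swap the
# attention modules of a stock transformer, then run a forward pass. Sizes and
# fp16 are arbitrary choices; fp16 simply matches what the Triton dot kernels
# target on tensor cores.
if __name__ == '__main__':
    model = nn.Transformer(d_model=256, nhead=8).cuda()
    replace_mah(model, mask=None)
    model = model.half()
    src = torch.randn(10, 2, 256, dtype=torch.float16, device='cuda')
    tgt = torch.randn(7, 2, 256, dtype=torch.float16, device='cuda')
    out = model(src, tgt)
    print(out.shape)  # torch.Size([7, 2, 256])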

View File

@@ -1,166 +0,0 @@
import triton
import torch.nn as nn
import torch
import torch.nn.functional as F
class _conv2d(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias,
stride, padding, dilation, groups,
acc_bitmask):
assert dilation == (1, 1)
assert groups == 1
assert bias is None
pad_h, pad_w = padding
stride_h, stride_w = stride
n, c, h, w = x.size()
k, c, r, s = weight.size()
# allocate output
p = (h + 2*padding[0] - r)//stride[0] + 1
q = (w + 2*padding[1] - s)//stride[1] + 1
output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device)
# padding
if pad_h or pad_w:
x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h])
# convolution
triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
x, weight, mask=acc_bitmask,
output=output,
values = {'pad_h': pad_h,
'stride_h': stride_h,
'pad_w': pad_w,
'stride_w': stride_w})
# prepare backprop
ctx.save_for_backward(x, weight)
ctx.stride = stride
ctx.padding = padding
ctx.acc_bitmask = acc_bitmask
# return
return output
@staticmethod
def backward(ctx, dy):
# retrieve contextual information
x, weight = ctx.saved_tensors
stride = ctx.stride
padding = ctx.padding
acc_bitmask = ctx.acc_bitmask
# gradient of the input
dx = None
if ctx.needs_input_grad[0]:
# dy must be padded
n, k, p, q = dy.size()
n, c, h, w = x.size()
k, c, r, s = weight.size()
dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
# have to be careful here
# the gradient of strided conv is a conv over a sparse image
# which can be decomposed as a set of smaller convs
dx = torch.empty_like(x)
for offh in range(stride[0]):
for offw in range(stride[1]):
poffh = (offh + padding[0]) % stride[0]
poffw = (offw + padding[1]) % stride[1]
pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
if poffh >= r or poffw >= s:
dx[:, :, offh::stride[0], offw::stride[1]] = 0
else:
triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
dypad[:, :, :, :],
weight[:, :, poffh::stride[0], poffw::stride[1]],
output = dx[:, :, offh::stride[0], offw::stride[1]],
mask = acc_bitmask,
values = {'pad_h': pad_h,
'pad_w': pad_w})
# gradient for the weight
dw = None
if ctx.needs_input_grad[1]:
dw = torch.empty_like(weight)
triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
x, dy, output = dw, mask = acc_bitmask)
#print('dw: ', dw.view(-1)[0])
return dx, dw, None, None, None, None, None, None
conv2d = _conv2d.apply
class Conv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
acc_bitmask = None):
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, padding_mode)
self.acc_bitmask = acc_bitmask
def forward(self, input):
#if self.kernel_size[0] == 3 and self.stride[0] != 1:
#print(self.padding, self.stride, input.size(), self.weight.size())
# return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups,
self.acc_bitmask)
def replace_conv2d(model, acc_bitmask = None):
for child_name, child in model.named_children():
if isinstance(child, nn.Conv2d):
conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size,
child.stride, child.padding, child.dilation, child.groups,
child.bias, child.padding_mode,
acc_bitmask=acc_bitmask)
for yparam, xparam in zip(conv2d.parameters(), child.parameters()):
yparam.data.copy_(xparam.data)
setattr(model, child_name, conv2d)
else:
replace_conv2d(child, acc_bitmask)
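# Usage sketch (illustrative): like replace_mah for attention layers, this
# walks a model and swaps every nn.Conv2d for the Triton-backed Conv2d, e.g.
#   model = torchvision.models.resnet18().cuda()   # torchvision: example only
#   replace_conv2d(model, acc_bitmask=None)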
# initialize input
#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])
if __name__ == '__main__':
N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
#N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
#N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
pad, stride = 0, 1
torch.manual_seed(0)
x = torch.randn((N, C, H, W)).cuda()
x.requires_grad_(True)
#x.data[:] = 1
# initialize layers
torch.manual_seed(0)
rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
torch.manual_seed(0)
tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
#rconv2d.weight.data[:] = 1
#tconv2d.weight.data[:] = 1
ry = rconv2d(x)
ty = tconv2d(x)
# reference
dy = torch.randn(ry.size()).cuda()
#dy.data[:] = 1
ry.backward(dy)
rdx = x.grad.clone()
rdw = rconv2d.weight.grad.clone()
x.grad.zero_()
# triton
ty.backward(dy)
tdx = x.grad.clone()
tdw = tconv2d.weight.grad.clone()
x.grad.zero_()
# print error
diff = lambda x, y: (x - y).abs().max()
print(diff(ry, ty))
print(diff(rdx, tdx))
print(diff(rdw, tdw))
#print((rdx - tdx).abs())
#print((rdx[0,0,:,:] - tdx[0,0,:,:]))
#print(rdx[0,0,:,:])
#print(tdx[0,0,:,:])

View File

@@ -1,13 +0,0 @@
import torch
import triton
def linear(x, w, bias = None):
m, k = x.size()
n, k = w.size()
# 'mk,nk->mn' expects w as (out_features, in_features), like nn.Linear
out = triton.ops.einsum('mk,nk->mn', x, w)
if bias is not None:
out += bias
return out
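# Example (sketch): with w stored as (out_features, in_features), this matches
# torch.nn.functional.linear for 2-D inputs, i.e. out = x @ w.t() + bias:
#   x = torch.randn(64, 128, device='cuda')
#   w = torch.randn(256, 128, device='cuda')
#   y = linear(x, w)   # y: (64, 256)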

View File

@@ -1,2 +0,0 @@
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm

View File

@@ -1,136 +0,0 @@
import triton
import torch
import math
class _batchnorm(torch.autograd.Function):
fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
float *X, float *G, float *B,
int N, float eps) {
// pointers
int c = get_program_id(1);
int rm[TM] = 0 ... TM;
float *px[TM] = X + rm + c*N;
float* py[TM] = Y + rm + c*N;
// compute mean
float accm[TM] = 0;
for(int i = 0; i < N; i = i + TM)
accm = accm + *(px + i);
float mean = (float)accm[+] / N;
*(M + c) = mean;
// compute variance
float accv[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
x = x - mean;
accv = accv + x*x;
}
float var = (float)accv[+] / N;
*(V + c) = var;
// Normalize batch
float gamma = *(G + c);
float beta = *(B + c);
float rstdg = 1 / sqrtf(var + eps) * gamma;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float y[TM] = (x - mean)*rstdg + beta;
*(py + i) = y;
}
}
"""
bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
float *DY, float *X, float *G,
float *M, float *V,
int N, float epsilon) {
// pointers
int c = get_program_id(1);
int rx[TM] = 0 ... TM;
int offset = c*N;
float* px[TM] = X + rx + offset;
float* pdy[TM] = DY + rx + offset;
float* pdx[TM] = DX + rx + offset;
// fetch statistics
float gamma = *(G + c);
float mean = *(M + c);
float var = *(V + c);
float rstd = 1 / sqrtf(var + epsilon);
// compute dgamma and dbeta
float acc_dg[TM] = 0;
float acc_db[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
acc_dg += dy*(x - mean)*rstd;
acc_db += dy;
}
float dg = acc_dg[+];
float db = acc_db[+];
*(DG + c) = dg;
*(DB + c) = db;
// compute dx
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
float xhat[TM] = (x - mean) * rstd;
float xtmp[TM] = (xhat * dg + db) / N;
float dx[TM] = (dy - xtmp) * rstd * gamma;
*(pdx + i) = dx;
}
}
"""
fwd_kernel = None
bwd_kernel = None
@staticmethod
def forward(ctx, x, gamma, beta, eps):
# lazy compilation of kernel
if _batchnorm.fwd_kernel is None:
_batchnorm.fwd_kernel = triton.kernel(_batchnorm.fwd_src, defines = {'TM': 128})
# shapes
shape = triton.shape(x)
dtype = x.dtype
# allocate outputs
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
y = triton.empty(shape, dtype=dtype)
mean = triton.empty([C], dtype=dtype)
var = triton.empty([C], dtype=dtype)
# execute kernels
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
grid = lambda opt: [1, C])
# save
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
# lazy compilation of kernel
if _batchnorm.bwd_kernel is None:
_batchnorm.bwd_kernel = triton.kernel(_batchnorm.bwd_src, defines = {'TM': 128})  # bwd_src tiles over TM
# retrieve info
x, gamma, beta, mean, var = ctx.saved_tensors
eps = ctx.eps
# allocate result
dx = triton.empty(triton.shape(x), dtype=x.dtype)
dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
# execute
C, H, W, B = triton.shape(x)
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
H*W*B, eps,
grid = lambda opt: [1, C])
return dx, dgamma, dbeta, None
batchnorm = _batchnorm.apply
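# Usage sketch (illustrative; note the kernels above normalize each channel
# over H*W*B, i.e. tensors are expected in (C, H, W, B) layout):
#   x = torch.randn(32, 8, 8, 16, device='cuda')
#   gamma = torch.ones(32, device='cuda')
#   beta = torch.zeros(32, device='cuda')
#   y = batchnorm(x, gamma, beta, 1e-5)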

View File

@@ -1,794 +0,0 @@
from math import ceil, log2
from enum import IntEnum
from functools import reduce
from operator import mul
from collections import OrderedDict
from collections import namedtuple
import re
import string
import triton
import torch
# numpy -- ideally removed in a future release
import numpy as np
# sympy -- ideally removed in a future release
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.printing.ccode import C89CodePrinter
class _einsum(torch.autograd.Function):
#############################
## Triton-C code generation
#############################
def print_cc(expr, axes_0, axes_1, axes_2, prefix):
if expr in axes_0:
return f'{prefix}r{expr}[:, newaxis, newaxis]'
if expr in axes_1:
return f'{prefix}r{expr}[newaxis, :, newaxis]'
if expr in axes_2:
return f'{prefix}r{expr}[newaxis, newaxis, :]'
return expr
def unpack_cc(tile, axes, prefix, remat):
ret = ''
axes = list(map(str, axes))
for i, d in enumerate(reversed(axes)):
if i == len(axes) - 1:
break
currs = ''.join(axes[: len(axes) - i])
nexts = ''.join(axes[: len(axes) - (i + 1)])
ty = '' if remat else 'int '
sz = '' if remat or tile is None else f'[{tile}]'
ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
return ret
def strides_cc(name, expr):
ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
ret = dict(zip(expr, ret))
return ret
def make_kernel(name, dtype,
expr_a, expr_b, expr_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
blocks):
use_lut_a = True
use_lut_b = True
outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]
src = ""
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""
src += f"""
__global__ void {name}(
TYPE * A __noalias __readonly __aligned(16)
, TYPE * B __noalias __readonly __aligned(16)
, TYPE * C
, int * locks
, float alpha
, int matmul_m, int matmul_n, int matmul_k __multipleof(16)
, int div_m
"""
for dim in [axes_m, axes_n, axes_k, axes_b]:
for d in dim:
src += f", int dim_{d}"
src += "\n "
for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
['a', 'b', 'c'],
[multipleof_a, multipleof_b, multipleof_c],
[sparse_a, sparse_b, sparse_c]):
for d in range(len(dim) - 1):
if sparse and dim[d] == sparse[0]:
src += f', int stride_{name}_block __multipleof({mult})'
src += f", int stride_{name}_{d} __multipleof({mult})"
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
src += f", int rem_delta_a __multipleof({multipleof_a})"
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
src += f", int rem_delta_b __multipleof({multipleof_b})"
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
src += "\n "
if sparse_c:
src += ", int* CD"
if sparse_a or sparse_b:
src += ", int width"
src += """) {
// program identifiers
int pid_0 = get_program_id(0);
int pid_1 = get_program_id(1);
"""
if sparse_a:
src += f"""
int off_n = pid_0 / width;
int off_header = pid_0 % width;
int* header = AD + off_header * {2 + len(outer_sparse_a)};
int* pdelta = AD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_a):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int inca = *(pdelta + 0);
int incb = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_a))} = pid_1;
"""
src += _einsum.unpack_cc(None, outer_dense_a, "off_", False)
elif sparse_b:
src += f"""
int off_m = pid_0 / width;
int off_header = pid_0 % width;
int* header = BD + off_header * {2 + len(outer_sparse_b)};
int* pdelta = BD + *(header + 0);
matmul_k = *(header + 1);"""
for i, d in enumerate(outer_sparse_b):
src += f"""
int off_{d} = *(header + {2 + i});"""
src += f"""
int incb = *(pdelta + 0);
int inca = *(pdelta + 1);
int off_{''.join(map(str, outer_dense_b))} = pid_1;
"""
src += _einsum.unpack_cc(None, outer_dense_b, "off_", False)
elif sparse_c:
src += f"""
// load LUT header
int *header = CD + pid_0 * {len(sparse_c)};"""
for i, d in enumerate(sparse_c):
src += f"""
int off_{d} = *(header + {i});"""
src += f"""
int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
else:
src += """
// re-order outer program ids
int grid_m = (matmul_m + TM - 1) / TM;
int grid_n = (matmul_n + TN - 1) / TN;
int off_mn = pid_0 / div_m;
int off_n = off_mn % grid_n;
int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
int off_b = get_program_id(1);"""
src += """
#if TZ == 1
int off_k = 0;
#else
// get reduction sub-group program id
int pid_z = get_program_id(2);
int grid_z = get_num_programs(2);
int div_z = matmul_k / TZ;
int rem_z = matmul_k % TZ;
int off_k = pid_z * div_z;
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
int rem_k = matmul_k % TK;
// create ranges
"""
sparse = sparse_a + sparse_b + sparse_c
for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
['TM', 'TN', 'TB', 'TK'],
['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
[['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
if not axes:
continue
currs = ''.join(map(str,axes))
has_sparse_component = set(axes) & set(sparse)
if has_sparse_component:
src += f" int r{currs}[{tile}] = 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
else:
src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
src += _einsum.unpack_cc(tile, axes, f'r', False)
for pfx in prefixes:
for d in axes:
is_dense_dim = d not in sparse
is_dense_storage = (pfx == 'a' and not sparse_a) or\
(pfx == 'b' and not sparse_b) or\
(pfx == 'c' and not sparse_c)
if not is_dense_dim and is_dense_storage:
src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
elif is_dense_dim and has_sparse_component:
src += f" int {pfx}r{d}[{tile}] = off_{d};\n"
else:
src += f" int {pfx}r{d}[{tile}] = r{d};\n"
src += f"""
// initialize pointers to A
int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pa[TM, TK, TB] = A + offa;"""
if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
int offadelta[TK] = off_k + 0 ... TK;
int {spec} *padelta[TK] = {cast}AD + offadelta;
int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
src += f"""
// initialize pointers to B
int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
src += f" + ({ccode}) * {stride}\n "
src += ';'
src += """
TYPE *pb[TK, TN, TB] = B + offb;"""
if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
src += f"""
// initialize pointers to B look-up table
int offbdelta[TK] = off_k + 0 ... TK;
int *pbdelta[TK] = BD + offbdelta;"""
rk = 'r{}'.format(''.join(map(str,axes_k)))
src += f"""
// prefetch
int prefetch_k = select(rem_k > 0, rem_k, TK);
bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
if sparse_a:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incda = *(pdelta + 0);
int incdb = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if sparse_b:
src += f"""
// update pointers to look-up tables
pdelta += 2;
int incdb = *(pdelta + 0);
int incda = *(pdelta + 1);
pa += incda;
pb += incdb;"""
if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += rem_delta_a;"""
elif not sparse_a and not sparse_b:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += rem_delta_b;"""
elif not sparse_a and not sparse_b:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
acc += a @ b;
// load inputs
checkk = k > TK;
checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;
// update pointers"""
if sparse_a:
src += """
pdelta += 2;
incda = *(pdelta + 0);
incdb = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
elif sparse_b:
src += """
pdelta += 2;
incdb = *(pdelta + 0);
incda = *(pdelta + 1);
pa += incda;
pb += incdb;
"""
else:
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += stride_b_inner;"""
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
}}
TYPE c[TM, TN, TB] = acc;
// initialize pointers to C
int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
src += f"\n + ({ccode}) * {stride}"
src += ';'
src += """
TYPE *pc[TM, TN, TB] = C + offc;
// bounds-checking
bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m;
bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n;
bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
checkcn[newaxis, :, newaxis];
// write back
#if TZ == 1
*?(checkc)pc = c;
#else
int *plock = locks + pid_mn + pid_b * get_num_programs(0);
int *pcount = plock + 1024*1024;
// spin
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc)pc = c;
else
*?(checkc)pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % (grid_z));
atomic_xchg(plock, 0);
#endif
}
"""
# compilation options
TM, TN, TB, TZ = [32], [32], 1, [1]
TK = 16 if dtype==torch.float16 else 8
defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
for d, B in blocks.items():
defines[f'BLOCK{d}'] = B
# create kernel
ret = triton.kernel(src, defines=defines)
# set constant
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('BD', delta_b)
return ret
############################
## Look-up Table
############################
class LUT_MODE(IntEnum):
SCALAR = 1
CONSTANT = 2
DRAM = 3
def lut_mode(delta):
if delta.size == 0 or np.min(delta) == np.max(delta):
return _einsum.LUT_MODE.SCALAR
#if delta.size < 4096:
# return _einsum.LUT_MODE.CONSTANT
return _einsum.LUT_MODE.DRAM
def symbolic_delta(symbols, axes):
rank = len(symbols)
strides = [sp.symbols(f'stride{d}') for d in range(rank)]
nexts = {s: sp.symbols(f'next{s}') for s in axes}
delta = 0
for i in range(rank):
delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
return delta
def unpack_offset(k, axes, dims):
ret = dict()
for d in reversed(axes):
ret[d] = k % dims[d]
k = k // dims[d]
return ret
def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
# depth of reductions
depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
# outer dimension indices
outer = torch.nonzero(depth, as_tuple=False)
outer = [outer[:,i] for i in range(outer.shape[1])]
# find offset of outer dimensions
depth = depth.view(-1)
offsets = torch.zeros_like(depth)
offsets[1:] = torch.cumsum(depth[:-1], 0)
# compute delta for b
# TODO: support multiple sparse red indices
col = next((i for i, d in enumerate(sparse) if d in axes), None)
block = blocks[sparse[-1].upper()]
div = block // step
delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
delta_b *= block
delta_b = [delta_b + step*i for i in range(div)]
delta_b = torch.stack(delta_b, dim=1)
delta_b = delta_b.view(-1)
# compute delta for a
bstride = 1
for d in sparse[::-1]:
if d in axes:
break
bstride *= blocks[d.upper()]
order = [d for d in sparse if d not in axes] +\
[d for d in sparse if d in axes]
idx = [sparse.index(d) for d in order]
layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
layout = layout.permute(*idx)
delta_a = layout[layout > 0] - 1
delta_a *= np.prod(list(blocks.values()))
saved = delta_a[offsets]
delta_a[1:] = delta_a[1:] - delta_a[:-1]
delta_a = delta_a.view(-1, 1).repeat(1, div)
delta_a[:, 1:] = step*bstride
delta_a[:, 0] -= (div - 1)*step*bstride
delta_a[offsets, 0] = saved
delta_a = delta_a.view(-1)
delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
# form look-up table
depth *= blocks[symbols[-1].upper()]
offsets *= div
header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
nouter = 2 + len(outer)
header[::nouter] = header[::nouter]*2 + header.shape[0]
lut = torch.cat((header, delta)).int().cpu().numpy()
return lut, nouter, _einsum.LUT_MODE.DRAM
def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None):
# symbolic pointer increments
symbols = [sp.symbols(x) for x in symbols]
delta = _einsum.symbolic_delta(symbols, axes)
args = [f'stride{d}' for d in range(len(stride))]
args += [f'{sk}' for sk in axes]
args += [f'next{sk}' for sk in axes]
fn = sp.lambdify(args, delta, 'numpy')
if lut is None:
# inner axes values
inner = [dims[d] for d in axes]
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
# nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
nextk = np.concatenate((k[:step] + rem, k[step:] + step))
else:
idx = (lut[:lut[0]:nouter] - lut[0])//2
k = lut[lut[0]+1::2]
k = np.insert(k, idx, 0)
nextk = k[1:]
k = k[:-1]
# offsets
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(nextk, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
delta = fn(*args)
delta = np.maximum(delta, 0)
if lut is not None:
idx = idx[1:] + np.arange(idx.shape[0] - 1)
delta = np.delete(delta, idx)
lut[lut[0]+1::2] = delta
return None, None
return delta, _einsum.lut_mode(delta[step:-step])
@staticmethod
def make_sdd_lut(layout_c, sparse_c, blocks):
nnz = torch.nonzero(layout_c, as_tuple=False)
lut = nnz.reshape(-1).int().cuda()
return lut
############################
## Compilation
############################
class instance:
locks = None
kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
sym_a = expr_a.lower()
sym_b = expr_b.lower()
sym_c = expr_c.lower()
sparse_a = [x.lower() for x in expr_a if x.isupper()]
sparse_b = [x.lower() for x in expr_b if x.isupper()]
sparse_c = [x.lower() for x in expr_c if x.isupper()]
layout_a = layouts.get(expr_a)
layout_b = layouts.get(expr_b)
layout_c = layouts.get(expr_c)
# parse axes
axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
axes = axes_b + axes_m + axes_n + axes_k
# check block sizes
for d in sparse_a + sparse_b + sparse_c:
if d.upper() not in blocks:
raise ValueError(f'unspecified block size for dimension: {d.upper()}')
# check layout is present
if sparse_a and layout_a is None:
raise ValueError('A is sparse but no layout provided')
if sparse_b and layout_b is None:
raise ValueError('B is sparse but no layout provided')
if sparse_c and layout_c is None:
raise ValueError('C is sparse but no layout provided')
# check dimensions
dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a])
dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b])
dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
# TODO: could be cleaner
read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
for d in axes_b + axes_m + axes_n + axes_k:
dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
if d in axes_b and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
if d in axes_k and dim_a and dim_b and dim_a != dim_b:
raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
dims = dict()
dims.update(dims_a)
dims.update(dims_b)
for i, d in enumerate(sparse_a):
dims[d] = layout_a.shape[i] * blocks[d.upper()]
for i, d in enumerate(sparse_b):
dims[d] = layout_b.shape[i] * blocks[d.upper()]
# allocate output
shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
if sparse_c:
shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
stride_c = [None] * len(shape_c)
stride_c[-1] = 1
for i in reversed(range(len(shape_c) - 1)):
stride_c[i] = stride_c[i+1] * shape_c[i+1]
# look-up tables
TK = 16 if dtype == torch.float16 else 8
if sparse_a and not sparse_b:
delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
if sparse_b and not sparse_a:
delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
if not sparse_a and not sparse_b:
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
if sparse_c:
delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
# hash for recompilation
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
stride_a_last = stride_a[-1]
stride_b_last = stride_b[-1]
stride_c_last = stride_c[-1]
name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
# recompile if necessary
cache = _einsum.instance.kernel_cache
if name not in cache:
cachesize = len(cache)
cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
dtype,
sym_a, sym_b, sym_c,
sparse_a, sparse_b, sparse_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
blocks)
self.kernel = cache[name]
# Initialize locks
if _einsum.instance.locks is None:
_einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
# Kernel arguments
dim_m = [dims[d] for d in axes_m]
dim_n = [dims[d] for d in axes_n]
dim_k = [dims[d] for d in axes_k]
dim_b = [dims[d] for d in axes_b]
M = reduce(mul, dim_m, 1)
N = reduce(mul, dim_n, 1)
K = reduce(mul, dim_k, 1)
B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
stride_a = list(stride_a[:-1])
stride_b = list(stride_b[:-1])
stride_c = list(stride_c[:-1])
alpha = 1.
div_m = 1
self.args = [None, None, None]
self.args += [_einsum.instance.locks]
self.args += [alpha, M, N, K, div_m]
self.args += dim_m
self.args += dim_n
self.args += dim_k
self.args += dim_b
self.args += stride_a
self.args += stride_b
self.args += stride_c
# LUT for A
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
self.args += [delta_a[TK], delta_a[0]]
elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_a).cuda()]
# LUT for B
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
self.args += [delta_b[TK], delta_b[0]]
elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_b).cuda()]
# LUT for C
if sparse_c:
self.args += [delta_c]
if sparse_a or sparse_b:
width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
self.args += [width]
# Grid
if sparse_a:
self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
elif sparse_b:
self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
elif sparse_c:
width = int(layout_c.sum())
self.grid = lambda opt: [width, B, opt.d('TZ')]
else:
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
triton.cdiv(B, opt.d('TB')),
opt.d('TZ')]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
self.pos_c = 2
# save information on the operation
self.expr_a = expr_a
self.expr_b = expr_b
self.expr_c = expr_c
self.matmul_B = B
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
# output shape
self.shape_c = shape_c
def run(self, a, b):
c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
self.kernel(*self.args, grid=self.grid)
return c
############################
## Forward
############################
instance_cache = dict()
registry = dict()
@staticmethod
def forward(ctx, expr, a, b, layouts, blocks):
# compile einsum instance
cache = _einsum.instance_cache
key = (expr, a.dtype,
a.stride(), b.stride(),
a.shape , b.shape)
if key not in cache:
cache[key] = _einsum.instance(expr, a.dtype,
a.stride(), b.stride(),
a.shape , b.shape ,
layouts, blocks)
instance = cache[key]
# run kernel (c is written in-place)
c = instance.run(a, b)
# save information in context
ctx.expr_a = instance.expr_a
ctx.expr_b = instance.expr_b
ctx.expr_c = instance.expr_c
ctx.matmul_B = instance.matmul_B
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.save_for_backward(a, b)
return c
############################
## Backward
############################
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
expr_a = ctx.expr_a
expr_b = ctx.expr_b
expr_c = ctx.expr_c
# gradient of first argument
da = None
if ctx.needs_input_grad[1]:
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b)
# gradient of second argument
db = None
if ctx.needs_input_grad[2]:
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy)
return None, da, db, None, None, None, None, None, None, None
def einsum(expr, a, b, layouts = None, blocks = dict()):
return _einsum.apply(expr, a, b, layouts, blocks)
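# Usage sketch (illustrative): a dense batched matmul. Upper-case letters in an
# expression mark block-sparse dimensions, configured through `layouts` and
# `blocks`; all-lower-case expressions are plain dense einsums.
#   a = torch.randn(8, 128, 64, device='cuda', dtype=torch.float16)
#   b = torch.randn(8, 64, 256, device='cuda', dtype=torch.float16)
#   c = einsum('bmk,bkn->bmn', a, b)   # c: (8, 128, 256)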

View File

@@ -21,7 +21,7 @@ int main() {
 // config_t{ord, x[0], x[1], 1280, 1280, 1280},
 // config_t{ord, x[0], x[1], 1536, 1536, 1536},
 // config_t{ord, x[0], x[1], 2048, 2048, 2048},
-config_t{ord, x[0], x[1], 4096, 4096, 4096},
+config_t{ord, x[0], x[1], 8192, 8192, 8192},
 // config_t{ord, x[0], x[1], 256, 16, 256},
 // config_t{ord, x[0], x[1], 512, 16, 512},

View File

@@ -102,7 +102,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
 stream->write(&*ddelta, true, 0, hdelta);
 // macros
-rt::function::options_space_t opt;
+rt::options_space_t opt;
 opt.defines.push_back({"TYPE", {ty}});
 opt.defines.push_back({"TM", {"128"}});
 opt.defines.push_back({"TN", {"128"}});
@@ -125,7 +125,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
 W*H*CI, W*H, W, 1,
 CO*S*R , CO*S, CO, 1,
 Q*P*CO, Q*P, Q, 1};
-auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
+auto grid = [Z,P,Q,CO](const rt::options_t& x) {
 return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
 ceil(CO , x.D<int>("TN")),
 (size_t)x.D<int>("TZ")};

View File

@@ -107,7 +107,7 @@ void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vecto
 auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
 auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
 // create options
-rt::function::options_space_t opt;
+rt::options_space_t opt;
 // macros

View File

@@ -101,46 +101,38 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
 // ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
 // macros
-rt::function::options_space_t opt;
+rt::options_space_t opts;
 // A access patterns
-opt.defines.push_back({"USEA", {AT? "a" : "a" }});
-opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }});
-opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }});
-opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }});
-opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
-opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
+opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
+opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
 // B access patterns
-opt.defines.push_back({"USEB", {BT? "b" : "b" }});
-opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }});
-opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }});
-opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }});
-opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
-opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
+opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
+opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
 // data-type
-opt.defines.push_back({"TYPE", {ty}});
+opts.defines.push_back({"TYPE", {ty}});
 // tile sizes
 if(mode == TEST) {
-opt.defines.push_back({"TM", {std::to_string(TM)}});
-opt.defines.push_back({"TN", {std::to_string(TN)}});
-opt.defines.push_back({"TK", {std::to_string(TK)}});
-opt.defines.push_back({"TZ", {"1"}});
-opt.num_warps = {nwarp};
+opts.defines.push_back({"TM", {std::to_string(TM)}});
+opts.defines.push_back({"TN", {std::to_string(TN)}});
+opts.defines.push_back({"TK", {std::to_string(TK)}});
+opts.defines.push_back({"TZ", {"1"}});
+opts.num_warps = {nwarp};
 }
 if(mode == BENCH) {
-opt.defines.push_back({"TM", {"64", "128"}});
-opt.defines.push_back({"TN", {"64", "128"}});
-opt.defines.push_back({"TK", {"16"}});
-opt.defines.push_back({"TZ", {"1"}});
-opt.num_warps = {4};
+opts.defines.push_back({"TM", {"128"}});
+opts.defines.push_back({"TN", {"128"}});
+opts.defines.push_back({"TK", {"32"}});
+opts.defines.push_back({"TZ", {"1"}});
+opts.num_warps = {4};
 }
 // kernels
-rt::function function(src::dot, opt);
+rt::function function(src::dot, opts);
 dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(),
 1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()};
-auto grid = [M, N](const rt::function::options_t& x) {
-return rt::grid_t{ceil(M, x.D<int>("TM")),
-ceil(N, x.D<int>("TN")),
+auto grid = [M, N](const rt::options_t& x) {
+return rt::grid_t{ceil(M, x.D<int>("TM"))*
+ceil(N, x.D<int>("TN")),
 (size_t)x.D<int>("TZ")};
 };
@@ -151,19 +143,25 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
 double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
 bench.push_back(tflops(triton_ns));
-// // cublas
-// if(cublas::cublasinit()){
-// T alpha(static_cast<double>(1));
-// T beta(static_cast<double>(0));
-// cublasGemmAlgo_t fastest;
+// cublas
+if(cublas::cublasinit()){
+T alpha(static_cast<double>(1));
+T beta(static_cast<double>(0));
+cublasGemmAlgo_t fastest;
 // cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
-// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-// ldc, nullptr, fastest); }, stream);
-// bench.push_back(tflops(cublas_ms));
-// }
+double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, !AT, !BT, M, N, K,
+&alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+ldc); }, stream);
+bench.push_back(tflops(cublas_ms));
+}
 }
+// rt::options_t opt;
+// for(auto &x: opts.defines)
+// opt.defines[x.first] = x.second[0];
+// opt.num_warps = 1;
+// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
 // test triton
 if(mode == TEST){
 srand(0);

View File

@@ -95,25 +95,25 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
 y_strides.push_back(y_strides[d] + " * " + y_shapename[y_order[d]]);
 // options
-rt::function::options_space_t opt;
-opt.defines.push_back({"TYPE", {ty}});
+rt::options_space_t opts;
+opts.defines.push_back({"TYPE", {ty}});
 for(int d = 0; d < rank_x; d++)
-opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
+opts.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
 for(int d = 0; d < rank_y; d++)
-opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
+opts.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
 if(TS.empty())
 TS = tile_nd(rank_x);
 for(int d = 0; d < rank_x; d++)
-opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
+opts.defines.push_back({"TS" + std::to_string(d), TS[d]});
 std::vector<size_t> axy;
 for(int d = 0; d < rank_x; d++)
 if(d != axis)
 axy.push_back(d);
 for(int d = 0; d < rank_y; d++)
-opt.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
+opts.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
 for(int d = 0; d < rank_y; d++)
-opt.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
+opts.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
 std::string RED = "";
 for(int n = 0; n < rank_x; n++){
@@ -121,30 +121,40 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
 RED += ", ";
 RED += (n==axis) ? to_str(op) : ":";
 }
-opt.defines.push_back({"RED", {RED}});
-opt.num_warps = {1};
+opts.defines.push_back({"RED", {RED}});
+opts.num_warps = {2};
 // kernel
-rt::function function(src::reduce_nd[rank_x - 1], opt);
+rt::function function(src::reduce_nd[rank_x - 1], opts);
 // input buffers
 auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_x*dtsize));
 auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
 // grid
-reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
-if(shape_x.size() > 1) args.S1 = shape_x[1];
-if(shape_x.size() > 2) args.S2 = shape_x[2];
+std::stringstream oss;
+rt::add_arg(oss, *dx->cu());
+rt::add_arg(oss, *dy->cu());
+rt::add_arg(oss, (uint32_t)shape_x[0]);
+if(shape_x.size() > 1) rt::add_arg(oss, (uint32_t)shape_x[1]);
+if(shape_x.size() > 2) rt::add_arg(oss, (uint32_t)shape_x[2]);
 std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
 auto grid = grid_nd(shape_x, ts);
 // metrics
 if(mode == BENCH){
 auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
-double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
+double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
 bench.push_back(gbps(triton_ns));
 }
+// rt::options_t opt;
+// for(auto &x: opts.defines)
+// opt.defines[x.first] = x.second[0];
+// opt.num_warps = 1;
+// std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
 // test triton
 if(mode == TEST){
 std::vector<NumericT> hy(size_y);
@@ -153,7 +163,7 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
 init_zeros(hy);
 init_rand(hx);
 stream->write(&*dx, true, 0, hx);
-function((void**)&args, sizeof(args), grid, stream, device);
+function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
 stream->synchronize();
 stream->read(&*dy, true, 0, hy);
 cc_reduce_nd(ry, hx, op, axis, shape_x);
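The hunk above replaces the fixed `reduce_arg_t` struct with a byte stream assembled by `rt::add_arg` and handed to the runtime as raw bytes. A Python sketch of the same serialization idea (illustrative only; pointer values and shapes are made up):

import struct

def pack_args(*typed_values):
    # append each (format, value) pair to one contiguous little-endian buffer,
    # mirroring how rt::add_arg streams arguments into a std::stringstream
    buf = b''
    for fmt, value in typed_values:
        buf += struct.pack(fmt, value)
    return buf

args = pack_args(('<Q', 0x7f0000000000),   # device pointer for dx (uint64)
                 ('<Q', 0x7f0000100000),   # device pointer for dy (uint64)
                 ('<I', 1024))             # shape_x[0] as uint32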

View File

@@ -2,6 +2,9 @@ namespace src {
 const char *dot =
 R"(
+#define STM 8
+#define STN 8
 __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
 TYPE * B __noalias __readonly __aligned(16),
 TYPE * C __noalias __aligned(16),
@@ -14,54 +17,59 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
 int ldc __multipleof(8),
 int* locks) {
 // prologue
-int ridx = get_program_id(0);
-int ridy = get_program_id(1);
-int ridz = get_program_id(2);
-int gridx = M / TM;
-int gridy = N / TN;
-int rid = ridx + ridy * gridx;
-ridx = rid / gridy;
-ridy = rid % gridy;
-int rm[TM] = ridx * TM + 0 ... TM;
-int rn[TN] = ridy * TN + 0 ... TN;
+int pid = get_program_id(0);
+int pidz = get_program_id(2);
+int gridm = (M + TM - 1) / TM;
+int gridn = (N + TN - 1) / TN;
+int width = STM*gridn;
+int stm = pid / width;
+int RSTM = min(gridm - stm*STM, STM);
+int stn = (pid % width) / (RSTM*STN);
+int RSTN = min(gridn - stn*STN, STN);
+int laneid = pid % (RSTM * RSTN);
+int lanem = laneid / RSTN;
+int lanen = laneid % RSTN;
+int pidm = stm*STM + lanem;
+int pidn = stn*STN + lanen;
+int rm[TM] = pidm * TM + 0 ... TM;
+int rn[TN] = pidn * TN + 0 ... TN;
 // reduction splitting
 K = K / TZ;
-int rk[TK] = ridz * K + 0 ... TK;
+int rk[TK] = pidz * K + 0 ... TK;
 // pointers to operands
-int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
-int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
-TYPE* pa[SHAPE_A] = A + offa;
-TYPE* pb[SHAPE_B] = B + offb;
+int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
+int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
+TYPE* pa[TM, TK] = A + offa;
+TYPE* pb[TK, TN] = B + offb;
 // prefetches operands
-bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
-bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
-TYPE a[SHAPE_A] = checka ? *pa : 0;
-TYPE b[SHAPE_B] = checkb ? *pb : 0;
+bool checka[TM, TK] = rk[newaxis, :] < K;
+bool checkb[TK, TN] = rk[:, newaxis] < K;
+TYPE a[TM, TK] = checka ? *pa : 0;
+TYPE b[TK, TN] = checkb ? *pb : 0;
+pa += TK * STRIDE_AK;
+pb += TK * STRIDE_BK;
 // reduction loop
 float acc[TM, TN] = 0;
 for(int k = K; k > 0; k -= TK){
-acc += USEA @ USEB;
-bool checka[SHAPE_A] = k > TK;
-bool checkb[SHAPE_B] = k > TK;
-pa += TK * STRIDE_AK;
-pb += TK * STRIDE_BK;
+bool checka[TM, TK] = k > TK;
+bool checkb[TK, TN] = k > TK;
+acc += a @ b;
 a = *?(checka)pa;
 b = *?(checkb)pb;
+pa += TK * STRIDE_AK;
+pb += TK * STRIDE_BK;
 }
 acc = acc * alpha;
 TYPE c[TM, TN] = acc;
 // epilogue
-int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
-int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
-int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
+int rcm[TM] = pidm * TM + 0 ... TM;
+int rcn[TN] = pidn * TN + 0 ... TN;
+int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
 TYPE* pc[TM, TN] = C + offc;
-bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
+bool checkc[TM, TN] = rcm[:, newaxis] < M &&
+rcn[newaxis, :] < N;
 #if (TZ==1)
 *?(checkc) pc = c;
 #else
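The rewritten prologue above swizzles the linear program id into STM x STN "super-groups" of tiles so that concurrently scheduled programs touch neighboring rows and columns of A and B (better L2 reuse). A Python mirror of the index arithmetic, for reading along (illustrative; names follow the kernel):

def swizzle(pid, M, N, TM, TN, STM=8, STN=8):
    # same arithmetic as the kernel prologue: group tiles into STM x STN
    # super-tiles, handling ragged groups at the grid boundary via RSTM/RSTN
    gridm = (M + TM - 1) // TM
    gridn = (N + TN - 1) // TN
    width = STM * gridn
    stm = pid // width
    RSTM = min(gridm - stm * STM, STM)
    stn = (pid % width) // (RSTM * STN)
    RSTN = min(gridn - stn * STN, STN)
    laneid = pid % (RSTM * RSTN)
    pidm = stm * STM + laneid // RSTN
    pidn = stn * STN + laneid % RSTN
    return pidm, pidn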

View File

@@ -19,13 +19,13 @@ inline size_t ceil(size_t x, size_t y) {
} }
inline rt::function::grid_fn_ty grid1d(size_t N) { inline rt::function::grid_fn_ty grid1d(size_t N) {
return [N](const rt::function::options_t& x) { return [N](const rt::options_t& x) {
return rt::grid_t{ceil(N, x.D<int>("TN"))}; return rt::grid_t{ceil(N, x.D<int>("TN"))};
}; };
} }
inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
return [M, N](const rt::function::options_t& x) { return [M, N](const rt::options_t& x) {
return rt::grid_t{ceil(M, x.D<int>("TM")), return rt::grid_t{ceil(M, x.D<int>("TM")),
ceil(N, x.D<int>("TN"))}; ceil(N, x.D<int>("TN"))};
}; };
@@ -33,7 +33,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape, inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
const std::vector<std::string>& ts) { const std::vector<std::string>& ts) {
return [&shape, &ts](const rt::function::options_t& x) { return [&shape, &ts](const rt::options_t& x) {
rt::grid_t ret; rt::grid_t ret;
for(size_t d = 0; d < shape.size(); d++) for(size_t d = 0; d < shape.size(); d++)
ret.push_back(ceil(shape[d], x.D<int>(ts[d]))); ret.push_back(ceil(shape[d], x.D<int>(ts[d])));
@@ -71,6 +71,12 @@ void init_zeros(std::vector<T>& x) {
     x[i] = 0;
 }
+template<class T>
+void init_ones(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 1;
+}
 /* ------------------------
  * Loop Nests
  * ------------------------ */
@@ -163,7 +169,7 @@ for(size_t i = 0; i < hc.size(); i++)
       std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
     return false;
   }
   return true;
 }
 }
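The check loop above prints the first mismatching elements and fails. With HALF now in the test matrix (see the dot tests below), an elementwise comparison presumably needs a tolerance rather than exact equality; a hedged sketch of one such policy (the harness's actual rule is not shown in this diff, and allclose, rtol, atol are illustrative names):

#include <cmath>
#include <cstdio>
#include <vector>

template <class T>
bool allclose(const std::vector<T>& hc, const std::vector<T>& rc,
              double rtol = 1e-2, double atol = 1e-3) {
  for (size_t i = 0; i < hc.size(); i++)
    if (std::fabs((double)hc[i] - (double)rc[i]) >
        atol + rtol * std::fabs((double)rc[i])) {
      // report the first offending element, like the loop above
      std::printf("%zu %f %f\n", i, (double)hc[i], (double)rc[i]);
      return false;
    }
  return true;
}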

View File

@@ -8,22 +8,58 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to test
-  typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
+  typedef std::tuple<dtype_t, bool, bool, int, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(int TM: std::vector<int>{32, 64, 128})
-  for(int TN: std::vector<int>{32, 64, 128})
-  for(int TK: std::vector<int>{16})
-  for(int nwarps: std::vector<int>{4})
-  for(bool AT: std::array<bool, 2>{false, true})
-  for(bool BT: std::array<bool, 2>{false, true}){
-    configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
-  }
+  for(dtype_t dtype: std::vector<dtype_t>{FLOAT, HALF})
+  for(bool AT: std::vector<bool>{false, true})
+  for(bool BT: std::vector<bool>{false, true}){
+    // 1 warp
+    configs.push_back({dtype, AT, BT, 16, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 32, 1});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 16, 64, 16, 1});
+      configs.push_back({dtype, AT, BT, 16, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 64, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 16, 64, 64, 1});
+    }
+    // 2 warps
+    configs.push_back({dtype, AT, BT, 64, 32, 64, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 64, 2});
+    configs.push_back({dtype, AT, BT, 64, 32, 16, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 16, 2});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 2});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 2});
+    // 4 warps
+    configs.push_back({dtype, AT, BT, 128, 64, 16, 4});
+    configs.push_back({dtype, AT, BT, 64, 128, 16, 4});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 4});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 4});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 128, 32, 64, 4});
+      configs.push_back({dtype, AT, BT, 32, 128, 64, 4});
+    }
+    // 8 warps
+    configs.push_back({dtype, AT, BT, 128, 256, 16, 8});
+    configs.push_back({dtype, AT, BT, 256, 128, 16, 8});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+    }
+  };
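The tile sizes grow with the warp count because the fp32 accumulator tile has to fit in the register file, which is also why this merge adds a compilation error for kernels that spill. A back-of-the-envelope check (assuming 32 threads per warp and one fp32 register per accumulator element; the diff itself does not state this):

#include <cstdio>

int main() {
  // largest 8-warp config above: TM=256, TN=128
  int TM = 256, TN = 128, nwarp = 8, threads = nwarp * 32;
  std::printf("%d accumulator registers per thread\n", TM * TN / threads);  // 128
  return 0;
}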
   // test
   dtype_t dtype;
   bool AT, BT;
   int M, N, K, TM, TN, TK, nwarp;
   for(const auto& c: configs){
-    std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
+    std::tie(dtype, AT, BT, TM, TN, TK, nwarp) = c;
+    M = TM;
+    N = TN;
+    K = TK;
     std::cout << "Testing " << c << " ... " << std::flush;
     if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
       std::cout << " Pass! " << std::endl;

View File

@@ -20,12 +20,15 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
   std::vector<config_t> configs = {
-    config_t{{8, 8, 4}, 2, ADD},
-    config_t{{32}, 0, ADD},
+    config_t{{64}, 0, ADD},
+    config_t{{128}, 0, MIN},
     config_t{{32, 32}, 0, MAX},
     config_t{{32, 32}, 1, ADD},
     config_t{{32, 64}, 0, ADD},
-    config_t{{64, 32}, 1, ADD}
+    config_t{{64, 32}, 1, ADD},
+    config_t{{8, 8, 4}, 2, ADD},
+    config_t{{8, 8, 4}, 0, ADD},
+    config_t{{8, 8, 4}, 1, ADD}
   };
   // does the work
   int axis;
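Each reduce config pairs an input shape with a reduction axis and an operator, and the new 3D cases cover every axis of {8, 8, 4}; the reduced axis is removed from the output shape, so {8, 8, 4} reduced along axis 2 yields {8, 8}. A small sketch of that shape rule (illustrative; reduced_shape is not a harness function):

#include <cstdio>
#include <vector>

std::vector<int> reduced_shape(std::vector<int> shape, int axis) {
  shape.erase(shape.begin() + axis);  // drop the reduced dimension
  return shape;
}

int main() {
  auto out = reduced_shape({8, 8, 4}, 2);
  for (int d : out) std::printf("%d ", d);  // prints: 8 8
  std::printf("\n");
  return 0;
}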