[GENERAL] Merged v1.0alpha into master. Added features are:
- A100 support via mma.16816 (see the sketch below)
- Thread swizzling for conflict-free shared memory accesses without padding (see the second sketch below)
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
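For reference, "mma.16816" is the m16n8k16 tensor-core shape introduced with sm_80. A minimal sketch of the PTX form this feature targets, assuming f16 inputs with f32 accumulation (illustrative only, not code from this commit):

    // Warp-wide m16n8k16 tensor-core multiply-accumulate on sm_80.
    // A occupies four .b32 registers, B two, and C/D four .f32 registers.
    __device__ void mma16816(float d[4], const unsigned a[4],
                             const unsigned b[2], const float c[4]) {
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
          : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
          : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
            "r"(b[0]), "r"(b[1]),
            "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    }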
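The swizzling feature (see the new include/triton/codegen/analysis/swizzle.h below) assigns each shared-memory layout a per_phase/max_phase/vec triple and XORs a row-dependent phase into the vector-aligned column index, so a warp's accesses land in distinct banks without padding rows. A toy model of that mapping, using a hypothetical swizzled_col helper that is not part of this commit:

    #include <cstdio>

    // Rows sharing one phase (per_phase of them) use the same XOR mask;
    // masks cycle every max_phase phases; columns move in units of `vec`
    // elements so vectorized accesses stay contiguous.
    int swizzled_col(int row, int col, int per_phase, int max_phase, int vec) {
      int phase = (row / per_phase) % max_phase;
      return ((col / vec) ^ phase) * vec + col % vec;
    }

    int main() {
      // per_phase=1, max_phase=8, vec=4: the 8x8 grid of vector columns forms
      // a Latin square, so no two rows collide on the same bank group.
      for (int r = 0; r < 8; ++r) {
        for (int c = 0; c < 8; ++c)
          printf("%d ", swizzled_col(r, c * 4, 1, 8, 4) / 4);
        printf("\n");
      }
      return 0;
    }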
@@ -28,7 +28,8 @@
 # We also want an user-specified LLVM_ROOT_DIR to take precedence over the
 # system default locations such as /usr/local/bin. Executing find_program()
 # multiples times is the approach recommended in the docs.
-set(llvm_config_names llvm-config-10 llvm-config-10.0 llvm-config100
+set(llvm_config_names llvm-config-11 llvm-config-11.0
+                      llvm-config-10 llvm-config-10.0 llvm-config100
                       llvm-config-9 llvm-config-9.0 llvm-config90
                       llvm-config-8 llvm-config-8.0 llvm-config80
                       llvm-config)
@@ -27,7 +27,7 @@ private:
|
|||||||
void update_graph_trans(ir::instruction *i);
|
void update_graph_trans(ir::instruction *i);
|
||||||
void update_graph_broadcast(ir::instruction *i);
|
void update_graph_broadcast(ir::instruction *i);
|
||||||
void update_graph_dot(ir::instruction *i);
|
void update_graph_dot(ir::instruction *i);
|
||||||
void update_graph_elementwise(ir::instruction *i);
|
void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
|
||||||
void update_graph_no_edge(ir::instruction *i);
|
void update_graph_no_edge(ir::instruction *i);
|
||||||
void update_graph(ir::instruction *i);
|
void update_graph(ir::instruction *i);
|
||||||
|
|
||||||
|
@@ -25,7 +25,7 @@ class axes;
 class align;
 class layout_visitor;
 class data_layout;
-class mma884_layout;
+class mma_layout;
 class scanline_layout;
 class shared_layout;
 
@@ -33,7 +33,7 @@ class shared_layout;
 class layout_visitor {
 public:
   virtual void visit_layout(data_layout *);
-  virtual void visit_layout_hmma_884(mma884_layout*) = 0;
+  virtual void visit_layout_mma(mma_layout*) = 0;
   virtual void visit_layout_scanline(scanline_layout*) = 0;
   virtual void visit_layout_shared(shared_layout*) = 0;
 };
@@ -41,7 +41,7 @@ public:
 class data_layout {
 protected:
   enum id_t {
-    HMMA_884,
+    MMA,
     SCANLINE,
     SHARED
   };
@@ -68,7 +68,7 @@ public:
   // visitor
   virtual void accept(layout_visitor* vst) = 0;
   // downcast
-  mma884_layout* to_mma884() { return downcast<mma884_layout>(HMMA_884); }
+  mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
   scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
   shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
   // accessors
@@ -77,9 +77,10 @@ public:
   const order_t& get_order() const { return order_; }
   const values_t& get_values() const { return values_;}
   int get_axis(size_t k) const { return axes_.at(k); }
+  std::vector<int> get_axes() const { return axes_; }
   const int get_order(size_t k) const { return order_.at(k); }
   // find the position of given axis
-  size_t find_axis(int to_find) const;
+  int find_axis(int to_find) const;
 
 
 private:
@@ -92,21 +93,29 @@ protected:
   shape_t shape_;
 };
 
-class mma884_layout: public data_layout {
+class mma_layout: public data_layout {
 public:
-  mma884_layout(size_t num_warps,
-                const std::vector<int>& axes,
-                const std::vector<unsigned>& shapes,
-                const std::vector<ir::value *> &values,
-                analysis::align* align);
-  void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
+  mma_layout(size_t num_warps,
+             const std::vector<int>& axes,
+             const std::vector<unsigned>& shapes,
+             const std::vector<ir::value *> &values,
+             analysis::align* align, target *tgt,
+             shared_layout* layout_a,
+             shared_layout* layout_b);
+  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
   // accessor
   int fpw(size_t k) { return fpw_.at(k); }
   int wpt(size_t k) { return wpt_.at(k); }
+  int spw(size_t k) { return spw_.at(k); }
+  int spt(size_t k) { return spt_.at(k); }
+  int rep(size_t k) { return rep_.at(k); }
 
 private:
   std::vector<int> fpw_;
+  std::vector<int> spw_;
   std::vector<int> wpt_;
+  std::vector<int> spt_;
+  std::vector<int> rep_;
 };
 
 struct scanline_layout: public data_layout {
@@ -138,7 +147,7 @@ private:
   static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
 
 public:
-  shared_layout(const data_layout *arg,
+  shared_layout(data_layout *arg,
                 const std::vector<int>& axes,
                 const std::vector<unsigned>& shapes,
                 const std::vector<ir::value *> &values_,
@@ -149,11 +158,22 @@ public:
   size_t get_size() { return size_; }
   ir::type* get_type() { return ty_; }
   double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
+  size_t get_num_per_phase() { return num_per_phase_; }
+  ir::value* hmma_dot_a() { return hmma_dot_a_; }
+  ir::value* hmma_dot_b() { return hmma_dot_b_; }
+  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
+  int get_mma_vec() { return mma_vec_;}
+  data_layout* get_arg_layout() { return arg_layout_; }
 
 private:
   size_t size_;
   ir::type *ty_;
   std::shared_ptr<double_buffer_info_t> double_buffer_;
+  size_t num_per_phase_;
+  ir::value* hmma_dot_a_;
+  ir::value* hmma_dot_b_;
+  data_layout* arg_layout_;
+  int mma_vec_;
 };
 
 
new file: include/triton/codegen/analysis/swizzle.h (43 lines)
@@ -0,0 +1,43 @@
+#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
+#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
+
+#include <map>
+
+namespace triton{
+
+namespace ir{
+class module;
+}
+
+namespace codegen{
+class target;
+
+namespace analysis{
+
+class layouts;
+class data_layout;
+
+class swizzle {
+public:
+  // constructor
+  swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
+  // accessors
+  int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
+  int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
+  int get_vec(data_layout* layout) { return vec_.at(layout); }
+  // run
+  void run(ir::module &mod);
+private:
+  layouts* layouts_;
+  target* tgt_;
+  std::map<data_layout*, int> per_phase_;
+  std::map<data_layout*, int> max_phase_;
+  std::map<data_layout*, int> vec_;
+};
+
+}
+}
+}
+
+
+#endif
@@ -5,13 +5,14 @@
 
 #include "triton/ir/visitor.h"
 #include "triton/codegen/analysis/layout.h"
-#include "triton/codegen/selection/machine_value.h"
 #include <functional>
 
 // forward
 namespace llvm{
   class Type;
   class Value;
+  class BasicBlock;
+  class Attribute;
   class Instruction;
   class Constant;
   class LLVMContext;
@@ -25,6 +26,13 @@ namespace llvm{
 }
 
 namespace triton{
+
+namespace ir{
+class attribute;
+class load_inst;
+class store_inst;
+}
 
 namespace codegen{
 
 // forward
@@ -36,6 +44,7 @@ class allocation;
 class cts;
 class axes;
 class layouts;
+class swizzle;
 }
 // typedef
 typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -43,17 +52,14 @@ typedef llvm::IRBuilder<llvm::ConstantFolder,
 typedef llvm::LLVMContext LLVMContext;
 typedef llvm::Type Type;
 typedef llvm::Value Value;
+typedef llvm::Attribute Attribute;
+typedef llvm::BasicBlock BasicBlock;
 typedef llvm::Module Module;
 typedef llvm::Instruction Instruction;
 typedef llvm::Constant Constant;
 typedef llvm::ArrayType ArrayType;
 typedef llvm::Function Function;
 typedef std::vector<Value*> indices_t;
-// forward
-class machine_data_layout;
-class tile;
-class shared_tile;
-class distributed_tile;
 class target;
 
 }
@@ -62,110 +68,129 @@ class target;
 namespace triton{
 namespace codegen{
 
+struct distributed_axis {
+  int contiguous;
+  std::vector<Value*> values;
+  Value* thread_id;
+};
+
 class generator: public ir::visitor, public analysis::layout_visitor {
 private:
-  void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
-  Value* get_value(ir::value *x, const indices_t& idx);
-  void set_value(ir::value *x, const indices_t& idx, Value* v);
+  void init_idx(ir::value *x);
+  Instruction* add_barrier();
+  Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
 
-  void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
-  void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
-  void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
-                       Type *c_ty, Function *f_mul_add);
 
   void finalize_shared_layout(analysis::shared_layout*);
   void finalize_function(ir::function*);
   void finalize_phi_node(ir::phi_node*);
 
+private:
+  Type *cvt(ir::type *ty);
+  llvm::Attribute cvt(ir::attribute attr);
+
 public:
   generator(analysis::axes *a_axes,
             analysis::layouts *layouts,
             analysis::align *alignment,
             analysis::allocation *alloc,
+            analysis::swizzle *swizzle,
            target *tgt,
            unsigned num_warps);
 
   void visit_value(ir::value* v);
 
   void visit_phi_node(ir::phi_node*);
   void visit_binary_operator(ir::binary_operator*);
   void visit_getelementptr_inst(ir::getelementptr_inst*);
 
   void visit_icmp_inst(ir::icmp_inst*);
   void visit_fcmp_inst(ir::fcmp_inst*);
   void visit_cast_inst(ir::cast_inst*);
 
   void visit_return_inst(ir::return_inst*);
   void visit_cond_branch_inst(ir::cond_branch_inst*);
   void visit_uncond_branch_inst(ir::uncond_branch_inst*);
+  void visit_load_inst(ir::load_inst*);
 
   void visit_unmasked_load_inst(ir::unmasked_load_inst*);
   void visit_masked_load_inst(ir::masked_load_inst*);
+  void visit_store_inst(ir::store_inst*);
   void visit_unmasked_store_inst(ir::unmasked_store_inst*);
   void visit_masked_store_inst(ir::masked_store_inst*);
 
   void visit_reshape_inst(ir::reshape_inst*);
   void visit_splat_inst(ir::splat_inst*);
   void visit_broadcast_inst(ir::broadcast_inst*);
   void visit_downcast_inst(ir::downcast_inst*);
 
   void visit_exp_inst(ir::exp_inst*);
   void visit_log_inst(ir::log_inst*);
 
   void visit_get_program_id_inst(ir::get_program_id_inst*);
   void visit_get_num_program_inst(ir::get_num_program_inst*);
   void visit_atomic_cas_inst(ir::atomic_cas_inst*);
   void visit_atomic_exch_inst(ir::atomic_exch_inst*);
   void visit_atomic_add_inst(ir::atomic_add_inst*);
+  void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
+  void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
+  void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
   void visit_dot_inst(ir::dot_inst*);
   void visit_trans_inst(ir::trans_inst*);
   void visit_sqrt_inst(ir::sqrt_inst*);
+  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
+  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
 
   void visit_recoalesce_inst(ir::recoalesce_inst*);
+  void visit_masked_load_async_inst(ir::masked_load_async_inst*);
   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
   void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
   void visit_barrier_inst(ir::barrier_inst*);
+  void visit_async_wait_inst(ir::async_wait_inst*);
   void visit_make_range_dyn(ir::make_range_dyn*);
   void visit_make_range(ir::make_range*);
 
   void visit_make_range_sta(ir::make_range_sta*);
   void visit_undef_value(ir::undef_value*);
   void visit_constant_int(ir::constant_int*);
   void visit_constant_fp(ir::constant_fp*);
   void visit_alloc_const(ir::alloc_const*);
 
   void visit_function(ir::function*);
   void visit_basic_block(ir::basic_block*);
   void visit_argument(ir::argument*);
+  void visit(ir::module &, llvm::Module &);
 
-  void visit_layout_hmma_884(analysis::mma884_layout*);
+  // layouts
+  void visit_layout_mma(analysis::mma_layout*);
   void visit_layout_scanline(analysis::scanline_layout*);
   void visit_layout_shared(analysis::shared_layout*);
 
-  void visit(ir::module &, llvm::Module &);
 
 private:
   LLVMContext *ctx_;
   Builder* builder_;
   Module *mod_;
 
-  std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
   analysis::axes *a_axes_;
+  analysis::swizzle *swizzle_;
   std::map<unsigned, distributed_axis> axes_;
-  std::map<ir::value *, Value *> vmap_;
-  std::map<ir::value *, tile *> tmap_;
   target *tgt_;
   analysis::layouts *layouts_;
   analysis::align *alignment_;
   analysis::allocation *alloc_;
-  Value *sh_mem_ptr_;
+  Value *shmem_;
   unsigned num_warps_;
 
   std::set<ir::value*> seen_;
 
+  std::map<analysis::data_layout*, Value*> offset_a_m_;
+  std::map<analysis::data_layout*, Value*> offset_a_k_;
+  std::map<analysis::data_layout*, Value*> offset_b_k_;
+  std::map<analysis::data_layout*, Value*> offset_b_n_;
+
+  std::map<analysis::data_layout*, Value*> shared_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_next_ptr_;
+  std::map<analysis::data_layout*, Value*> shared_off_;
+
 
+  std::map<ir::value*, Value*> shmems_;
+  std::map<ir::value*, Value*> shoffs_;
+  std::map<ir::value*, std::vector<indices_t>> idxs_;
+  std::map<ir::value*, std::map<indices_t, Value*>> vals_;
+  std::map<ir::value*, BasicBlock *> bbs_;
+  std::map<ir::value*, std::vector<int>> ords_;
 
 };
 
 }
deleted file (138 lines)
@@ -1,138 +0,0 @@
-#pragma once
-
-#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
-#define _TRITON_SELECTION_MACHINE_LAYOUT_H_
-
-#include <map>
-#include "triton/codegen/analysis/layout.h"
-
-namespace llvm{
-  class Type;
-  class Value;
-  class Instruction;
-  class Constant;
-  class LLVMContext;
-  class Module;
-  class ConstantFolder;
-  class IRBuilderDefaultInserter;
-  template <typename T, typename Inserter>
-  class IRBuilder;
-  class ArrayType;
-  class Function;
-}
-
-namespace triton{
-
-namespace ir{
-class value;
-}
-
-namespace codegen{
-
-namespace analysis{
-class liveness;
-class tiles;
-class align;
-class allocation;
-class cts;
-class axes;
-class layouts;
-}
-
-typedef llvm::IRBuilder<llvm::ConstantFolder,
-                        llvm::IRBuilderDefaultInserter> Builder;
-typedef llvm::LLVMContext LLVMContext;
-typedef llvm::Type Type;
-typedef llvm::Value Value;
-typedef llvm::Module Module;
-typedef llvm::Instruction Instruction;
-typedef llvm::Constant Constant;
-typedef llvm::ArrayType ArrayType;
-typedef llvm::Function Function;
-
-class distributed_axis;
-class machine_data_layout;
-class tile;
-class shared_tile;
-class distributed_tile;
-class target;
-
-}
-}
-
-namespace triton{
-namespace codegen{
-
-
-class machine_data_layout {
-public:
-  virtual tile* create(ir::value *v) = 0;
-};
-
-class machine_shared_layout: public machine_data_layout {
-public:
-  machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
-                        analysis::shared_layout* layout,
-                        std::map<ir::value *, Value *>& vmap,
-                        std::map<ir::value *, tile *>& tmap);
-
-  tile* create(ir::value *v);
-
-  Module *mod_;
-  Builder *builder_;
-  target *tgt_;
-  analysis::allocation* alloc_;
-  Value *&sh_mem_ptr_;
-  analysis::shared_layout* layout_;
-  std::map<ir::value *, Value *>& vmap_;
-  std::map<ir::value *, tile *>& tmap_;
-
-  Value *offset_;
-  Value *ptr_;
-  Value *pre_ptr_;
-  Value *next_ptr_;
-
-};
-
-class machine_distributed_layout: public machine_data_layout {
-public:
-  machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
-                             analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                             analysis::data_layout* layout);
-
-  tile* create(ir::value *v);
-  Module *mod_;
-  Builder *builder_;
-  target *tgt_;
-  analysis::axes *a_axes_;
-  std::map<unsigned, distributed_axis>& axes_;
-  analysis::data_layout* layout_;
-};
-
-
-class machine_mma884_layout: public machine_distributed_layout {
-public:
-  machine_mma884_layout(Module *mod, Builder *builder,
-                        target *tgt,
-                        analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                        analysis::mma884_layout* layout);
-  Value *offset_a_i_, *offset_a_k_;
-  Value *offset_b_j_, *offset_b_k_;
-  unsigned pack_size_0_;
-  unsigned pack_size_1_;
-  unsigned num_packs_0_;
-  unsigned num_packs_1_;
-};
-
-class machine_scanline_layout: public machine_distributed_layout {
-public:
-  machine_scanline_layout(Module *mod, Builder *builder,
-                          target *tgt,
-                          analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                          analysis::scanline_layout* layout);
-};
-
-}
-}
-
-#endif
deleted file (152 lines)
@@ -1,152 +0,0 @@
-#pragma once
-
-#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
-#define _TRITON_SELECTION_MACHINE_VALUE_H_
-
-#include <vector>
-#include <map>
-#include <functional>
-
-namespace llvm{
-  class Type;
-  class Value;
-  class Instruction;
-  class Constant;
-  class LLVMContext;
-  class Module;
-  class ConstantFolder;
-  class IRBuilderDefaultInserter;
-  template <typename T, typename Inserter>
-  class IRBuilder;
-  class ArrayType;
-  class Function;
-}
-
-namespace triton{
-namespace codegen{
-typedef llvm::IRBuilder<llvm::ConstantFolder,
-                        llvm::IRBuilderDefaultInserter> Builder;
-typedef llvm::LLVMContext LLVMContext;
-typedef llvm::Type Type;
-typedef llvm::Value Value;
-typedef llvm::Module Module;
-typedef llvm::Instruction Instruction;
-typedef llvm::Constant Constant;
-typedef llvm::ArrayType ArrayType;
-typedef llvm::Function Function;
-}
-}
-
-namespace triton{
-namespace codegen{
-
-namespace analysis{
-class liveness;
-class tiles;
-class align;
-class allocation;
-class cts;
-class axes;
-class layouts;
-}
-
-class distributed_axis;
-class machine_data_layout;
-class tile;
-class shared_tile;
-class distributed_tile;
-class target;
-typedef std::vector<Value*> indices_t;
-
-}
-}
-
-namespace triton{
-namespace codegen{
-
-struct distributed_axis {
-  int contiguous;
-  std::vector<Value*> values;
-  Value* thread_id;
-};
-
-class tile {
-protected:
-  typedef std::vector<unsigned> shapes_t;
-
-public:
-  tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
-  virtual void set_value(indices_t idx, Value *v) = 0;
-  virtual Value* get_value(indices_t idx) = 0;
-  Type *get_ty() const { return ty_; }
-  shapes_t get_shapes() const { return shapes_; }
-
-protected:
-  Type *ty_;
-  shapes_t shapes_;
-};
-
-class shared_tile: public tile {
-private:
-  void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
-  void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
-
-
-public:
-  shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
-  void set_vector_size(unsigned vector_size);
-  void set_return_mode(bool return_vector);
-  void set_value(indices_t, Value *);
-  Value* get_ptr_to(indices_t idx);
-  Value* get_value(indices_t idx);
-  Value* get_pointer() { return ptr_; }
-  Value* get_offset() { return offset_; }
-  const std::vector<int>& get_perm() { return perm_; }
-  const std::vector<int>& get_order() { return order_; }
-  static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);
-
-private:
-  Value *ptr_;
-  bool return_vector_;
-  Builder &builder_;
-  Value *offset_;
-  std::map<indices_t, Value*> ptr_cache_;
-  unsigned vector_size_;
-  std::vector<int> order_;
-  std::vector<int> perm_;
-};
-
-// Distribtued tile
-class distributed_tile: public tile{
-  typedef std::vector<distributed_axis> axes_t;
-  typedef std::vector<indices_t> ordered_indices_vec_t;
-  typedef std::map<indices_t, unsigned> indices_map_t;
-  typedef std::map<indices_t, Value*> values_map_t;
-
-private:
-  void init_indices();
-
-public:
-  distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
-  void set_value(indices_t idx, Value *v);
-  Value* get_value(indices_t idx);
-  const std::vector<int>& get_order() { return order_; }
-  unsigned get_linear_index(indices_t idx);
-  indices_t get_ordered_indices(unsigned id);
-  void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
-  void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);
-
-  const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
-private:
-  axes_t axes_;
-  std::vector<int> order_;
-  indices_map_t indices_;
-  values_map_t values_;
-  ordered_indices_vec_t ordered_indices_;
-  Builder &builder_;
-};
-
-}
-}
-
-#endif
@@ -35,6 +35,8 @@ namespace codegen{
 namespace triton{
 namespace codegen{
 
+class nvidia_cu_target;
+
 class target {
 public:
   target(bool is_gpu): is_gpu_(is_gpu){}
@@ -47,6 +49,7 @@ public:
   virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
   virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
   virtual unsigned guaranteed_alignment() = 0;
+  nvidia_cu_target* as_nvidia();
   bool is_gpu() const;
 
 private:
@@ -68,7 +71,7 @@ public:
 
 class nvidia_cu_target: public target {
 public:
-  nvidia_cu_target(): target(true){}
+  nvidia_cu_target(int sm): target(true), sm_(sm){}
   void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
   Instruction* add_barrier(Module *module, Builder& builder);
   Instruction* add_memfence(Module *module, Builder& builder);
@@ -76,7 +79,11 @@ public:
   Value* get_local_id(Module *module, Builder& builder, unsigned ax);
   Value* get_block_id(Module *module, Builder& builder, unsigned ax);
   Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
+  int sm() { return sm_; }
   unsigned guaranteed_alignment() { return 16; }
 
+private:
+  int sm_;
 };
 
 class cpu_target: public target {
@@ -11,14 +11,22 @@ namespace ir {
 class value;
 class phi_node;
 class instruction;
+class builder;
 }
 
 namespace codegen{
 namespace transform{
 
 class cts {
+private:
+  void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
+
 public:
+  cts(bool use_async = false): use_async_(use_async) {}
   void run(ir::module &mod);
+
+private:
+  bool use_async_;
 };
 
 }
@@ -1,6 +1,8 @@
 #ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
 #define TDL_INCLUDE_CODEGEN_BARRIERS_H
 
+#include <vector>
+
 namespace triton {
 
 namespace ir {
@@ -31,14 +33,14 @@ private:
 
 private:
   interval_vec_t join(const std::vector<interval_vec_t>& intervals);
-  void insert_barrier(ir::instruction *instr, ir::builder &builder);
+  void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
   bool intersect(const interval_vec_t &X, interval_t x);
   bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
   void add_reference(ir::value *v, interval_vec_t &res);
   void get_read_intervals(ir::instruction *i, interval_vec_t &res);
   void get_written_intervals(ir::instruction *i, interval_vec_t &res);
   std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
-                                                     std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
+                                                     std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
 
 public:
   membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
@@ -1,6 +1,7 @@
 #ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
 #define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
 
+#include "triton/codegen/target.h"
 
 namespace triton {
 
@@ -27,12 +28,16 @@ private:
   bool rewrite_mult(ir::instruction *value, ir::builder& builder);
   bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
+  bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
 
 private:
 
 public:
-  peephole() {}
+  peephole(target* tgt): tgt_(tgt) {}
   void run(ir::module &mod);
 
+private:
+  target* tgt_;
 };
 
 
new file: include/triton/codegen/transform/reorder.h (26 lines)
@@ -0,0 +1,26 @@
+#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
+#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
+
+namespace triton {
+
+// forward declaration
+namespace ir {
+class module;
+}
+
+namespace codegen{
+
+namespace transform{
+
+class reorder {
+public:
+  void run(ir::module& module);
+};
+
+}
+
+}
+
+}
+
+#endif
@@ -39,43 +39,23 @@ public:
 
 // CUDA device
 class cu_device: public device {
-public:
-  //Supported architectures
-  enum class Architecture{
-    //NVidia
-    SM_2_0,
-    SM_2_1,
-    SM_3_0,
-    SM_3_5,
-    SM_3_7,
-    SM_5_0,
-    SM_5_2,
-    SM_6_0,
-    SM_6_1,
-    SM_7_0,
-    UNKNOWN
-  };
-
 private:
   //Metaprogramming elper to get cuda info from attribute
   template<CUdevice_attribute attr>
   int cuGetInfo() const;
 
-  inline Architecture nv_arch(std::pair<unsigned int, unsigned int> sm) const;
   inline nvmlDevice_t nvml_device() const;
 
 public:
   cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
-  // Accessors
-  Architecture architecture() const;
   // Informations
   std::string infos() const;
   size_t address_bits() const;
   std::vector<size_t> max_block_dim() const;
   size_t warp_size() const;
   // Compute Capability
-  void interpret_as(std::pair<size_t, size_t> cc);
-  std::pair<size_t, size_t> compute_capability() const;
+  void interpret_as(int cc);
+  int compute_capability() const;
   // Identifier
   std::string name() const;
   std::string pci_bus_id() const;
@@ -91,7 +71,7 @@ public:
   std::unique_ptr<codegen::target> make_target() const;
 
 private:
-  std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
+  std::shared_ptr<int> interpreted_as_;
 };
 
 }
@@ -19,18 +19,18 @@ namespace triton
 namespace nvrtc
 {
 
-#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
+#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
 
-ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
-ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
-ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
-ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
-ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
-ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
+TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
+TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
+TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
+TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
+TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
+TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
 
-#undef ISAAC_CREATE_NVRTC_EXCEPTION
+#undef TRITON_CREATE_NVRTC_EXCEPTION
 }
 
 
@@ -38,107 +38,107 @@ namespace triton
 {
 class base: public std::exception{};
 
-#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
+#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
 
 
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
-ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
-ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
-ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
-ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
-ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
-ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
-ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
-ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
-ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
-ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
-ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
-ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
-ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
-ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
-ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
-ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
-ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found");
-ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
-ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
-ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
-ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
-ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
-ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
-ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
-ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
-ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
-ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
-ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
-ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
-ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
-ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
-ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
-ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
+TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
+TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
+TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
+TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
+TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
+TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
+TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
+TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
+TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
+TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
+TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
+TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
+TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
+TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
+TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
+TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
+TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
+TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
+TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
+TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
+TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
+TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
+TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
+TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
+TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
+TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
+TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
+TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
+TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
+TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
+TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
+TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
+TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
+TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
+TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
+TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
+TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
+TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
 
-#undef ISAAC_CREATE_CUDA_EXCEPTION
+#undef TRITON_CREATE_CUDA_EXCEPTION
 }
 
 namespace cublas
 {
 class base: public std::exception{};
 
-#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
+#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
 
-ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
-ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
-ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
-ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
-ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
+TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
+TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
+TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
+TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
+TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
+TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
+TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
 
-#undef ISAAC_CREATE_CUBLAS_EXCEPTION
+#undef TRITON_CREATE_CUBLAS_EXCEPTION
 }
 
 namespace cudnn
 {
-#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
+#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
 
-ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
-ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
-ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
-ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
-ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
-ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
-ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
-ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
-ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
-ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
-ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
+TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
+TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
+TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
+TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
+TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
+TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
+TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
+TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
+TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
+TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
+TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
 }
 
 }
@@ -44,6 +44,13 @@ public:
                      const std::string &features,
                      file_type_t file_type);
   virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
+  std::string llir() const { return llir_; }
+  int spilled() const { return spilled_; }
+
+private:
+  std::string llir_;
+protected:
+  int spilled_;
 };
 
 // CPU
@@ -59,12 +66,12 @@ class cu_module: public module {
 
 public:
   cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
-  cu_module(const std::string& source);
+  cu_module(driver::device* device, const std::string& source);
   std::unique_ptr<buffer> symbol(const char * name) const;
-  const std::string& source() const { return source_; }
+  const std::string& ptx() const { return ptx_; }
 
 private:
-  std::string source_;
+  std::string ptx_;
 };
 
 
@@ -146,8 +146,10 @@ public:
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
+  value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
   value *create_copy_from_shared(value *arg, const std::string &name = "");
   value *create_barrier(const std::string &name = "");
+  value *create_async_wait();
 
 private:
   context &ctx_;
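The two intrinsics added above are designed to be used as a pair when pipelining tile loads through shared memory. A minimal sketch of what a lowering pass might emit; only the two create_* signatures come from this header, while the builder reference and the ptr/mask/false_val values are assumed to exist in the caller:

    #include "triton/ir/builder.h"

    // Illustrative only: asynchronously copy a masked tile into shared
    // memory, then wait for the copy group and synchronize the block.
    triton::ir::value* emit_pipelined_load(triton::ir::builder &b,
                                           triton::ir::value *ptr,
                                           triton::ir::value *mask,
                                           triton::ir::value *false_val) {
      triton::ir::value *tile = b.create_masked_load_async(ptr, mask, false_val);
      // ...independent work can be issued here to overlap with the copy...
      b.create_async_wait();   // wait for in-flight async copies
      b.create_barrier();      // make the tile visible to the whole block
      return tile;
    }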
@@ -7,7 +7,7 @@ namespace triton{
 namespace ir{
 
 
-enum binary_op_t {
+enum binary_op_t: unsigned int{
   Add,
   FAdd,
   Sub,
@@ -28,7 +28,7 @@ enum binary_op_t {
   Xor
 };
 
-enum cast_op_t {
+enum cast_op_t: unsigned int {
   Trunc,
   ZExt,
   SExt,
@@ -44,7 +44,7 @@ enum cast_op_t {
   AddrSpaceCast
 };
 
-enum cmp_pred_t {
+enum cmp_pred_t: unsigned int {
   FIRST_FCMP_PREDICATE,
   FCMP_FALSE,
   FCMP_OEQ,
@@ -113,6 +113,7 @@ enum value_id_t: unsigned {
   // io
   INST_UNMASKED_LOAD,
   INST_MASKED_LOAD,
+  INST_MASKED_LOAD_ASYNC,
   INST_UNMASKED_STORE,
   INST_MASKED_STORE,
   // retile
@@ -139,6 +140,7 @@ enum value_id_t: unsigned {
   INST_COPY_FROM_SHARED,
   INST_RECOALESCE,
   INST_BARRIER,
+  INST_ASYNC_WAIT,
   INST_MAKE_RANGE_DYN,
   INST_MAKE_RANGE_STA,
   INST_MAKE_RANGE
@@ -72,6 +72,7 @@ public:
       case noalias: return ".noalias";
       case aligned: return ".aligned(" + std::to_string(value_) + ")";
       case multiple_of: return ".readonly";
+      case retune: return ".retune";
       default: break;
     }
     assert(false);
@@ -64,9 +64,10 @@ public:
   // cloning
   ir::instruction* clone() {
     ir::instruction* res = clone_impl();
-    for(auto it = op_begin(); it != op_end(); it++)
-      (*it)->add_use(res);
+//    for(auto it = op_begin(); it != op_end(); it++)
+//      (*it)->add_use(res);
     res->parent_ = nullptr;
+    res->users_.clear();
     return res;
   }
   // instruction id
@@ -431,6 +432,25 @@ public:
   _TRITON_DEFINE_ACCEPT(masked_load_inst)
 };
 
+// masked load async
+class masked_load_async_inst: public load_inst {
+private:
+  std::string repr_impl() const { return "masked_load_async"; }
+  masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                         const std::string &name, instruction *next);
+
+public:
+  // accessors
+  value *get_mask_operand() { return get_operand(1); }
+  value *get_false_value_operand() { return get_operand(2); }
+  // factory method
+  static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
+                                        const std::string &name = "",
+                                        instruction *next = nullptr);
+  _TRITON_DEFINE_CLONE(masked_load_async_inst)
+  _TRITON_DEFINE_ACCEPT(masked_load_async_inst)
+};
+
 class atomic_add_inst: public io_inst {
 private:
   atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
@@ -757,6 +777,7 @@ public:
   _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
 };
 
+
 class recoalesce_inst: public unary_inst{
 private:
   using unary_inst::unary_inst;
@@ -780,6 +801,18 @@ public:
                                  instruction *next = nullptr);
 };
 
+class async_wait_inst: public instruction{
+private:
+  async_wait_inst(context &ctx, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "async_wait"; }
+  _TRITON_DEFINE_CLONE(async_wait_inst)
+  _TRITON_DEFINE_ACCEPT(async_wait_inst)
+
+public:
+  static async_wait_inst* create(context &ctx, const std::string &name = "",
+                                 instruction *next = nullptr);
+};
+
 // On NVIDIA, implementation is such that
 // constant_range = nv_dynamic_program_idx + nv_static_program_idx
 // so as to enable re-association on nv_static_program_idx which is constant
@@ -65,7 +65,9 @@ class select_inst;
 class recoalesce_inst;
 class copy_to_shared_inst;
 class copy_from_shared_inst;
+class masked_load_async_inst;
 class barrier_inst;
+class async_wait_inst;
 class make_range_dyn;
 class make_range;
 
@@ -139,7 +141,9 @@ public:
   virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
   virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
   virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
+  virtual void visit_masked_load_async_inst(masked_load_async_inst*) = 0;
   virtual void visit_barrier_inst(barrier_inst*) = 0;
+  virtual void visit_async_wait_inst(async_wait_inst*) = 0;
   virtual void visit_make_range_dyn(make_range_dyn*) = 0;
   virtual void visit_make_range(make_range*) = 0;
 
include/triton/runtime/error.h (new file, 34 lines)
@@ -0,0 +1,34 @@
+#pragma once
+
+#ifndef _TRITON_RUNTIME_ERROR_H_
+#define _TRITON_RUNTIME_ERROR_H_
+
+#include <exception>
+#include <string>
+
+namespace triton {
+namespace runtime{
+namespace exception {
+
+class base: public std::exception {};
+#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };
+
+TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
+TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")
+
+class no_valid_configuration: public exception::base {
+public:
+  no_valid_configuration(const std::string& err): err_(err) { }
+  const char * what() const throw(){ return err_.c_str(); }
+private:
+  std::string err_;
+};
+
+
+#undef TRITON_CREATE_RUNTIME_EXCEPTION
+
+}
+}
+}
+
+#endif
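A sketch of how callers are expected to consume these new exception types; the launch function below is a hypothetical stand-in for any compile/launch path that can exhaust on-chip resources, while the exception classes and messages come from the header above:

    #include "triton/runtime/error.h"
    #include <iostream>

    // Hypothetical stand-in: a real JIT path would throw when a kernel
    // spills registers or exceeds the shared-memory budget.
    void launch_kernel() {
      throw triton::runtime::exception::out_of_registers();
    }

    int main() {
      try {
        launch_kernel();
      }
      catch(const triton::runtime::exception::out_of_registers &e) {
        std::cerr << e.what() << std::endl;  // "Triton: Error - Runtime: out of registers"
      }
      catch(const triton::runtime::exception::out_of_shared_memory &e) {
        std::cerr << e.what() << std::endl;
      }
    }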
@@ -6,6 +6,7 @@
 #include <map>
 #include <vector>
 #include <string>
+#include <sstream>
 #include <memory>
 #include <functional>
 #include <set>
@@ -13,6 +14,7 @@
 #include "triton/ir/context.h"
 #include "triton/codegen/target.h"
 #include "triton/runtime/arg.h"
+#include "triton/runtime/error.h"
 
 namespace llvm {
   class Module;
@@ -56,33 +58,43 @@ template<typename T> inline T convert(const std::string& name);
 template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
 template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
 
+template<class T>
+void add_arg(std::stringstream& ss, T arg) {
+  ss.write((char*)&arg, sizeof(T));
+}
+
+enum asm_mode_t {
+  ASM_LLIR,
+  ASM_NV_PTX,
+  ASM_NV_SASS
+};
+
+struct options_space_t {
+  typedef std::pair<std::string, std::vector<std::string>> define_t;
+  std::vector<define_t> defines;
+  std::vector<int> num_warps;
+  std::vector<int> recompile_key;
+};
+
+struct options_t {
+  template<class T>
+  T D(const std::string& name) const {
+    return convert<T>(defines.at(name));
+  }
+  bool operator<(const options_t& other) const {
+    return std::make_pair(defines, num_warps) <
+           std::make_pair(other.defines, other.num_warps);
+  }
+  std::string to_str() const;
+
+  std::map<std::string, std::string> defines;
+  size_t num_warps;
+};
+
 class function {
 public:
-  struct options_space_t {
-    typedef std::pair<std::string, std::vector<std::string>> define_t;
-    std::vector<define_t> defines;
-    std::vector<int> num_warps;
-    std::vector<int> recompile_key;
-  };
-
-  struct options_t {
-    template<class T>
-    T D(const std::string& name) const {
-      return convert<T>(defines.at(name));
-    }
-    bool operator<(const options_t& other) const {
-      return std::make_pair(defines, num_warps) <
-             std::make_pair(other.defines, other.num_warps);
-    }
-    std::string to_str() const;
-
-    std::map<std::string, std::string> defines;
-    size_t num_warps;
-  };
-
   typedef std::function<grid_t(const options_t&)> grid_fn_ty;
 
 
 private:
   class caller {
   public:
@@ -135,7 +147,7 @@ public:
   void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
   void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
   void set_cst(const char* name, void* data, size_t n_bytes);
-  std::string ptx(driver::device *device, const options_t& opt);
+  std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt);
 
 private:
   std::map<std::string, std::vector<char>> cst_;
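With the option types hoisted to namespace scope, an autotuning space can be built without referring to `function`, and `get_asm` can return LLVM-IR, PTX, or SASS according to `asm_mode_t`. A sketch under stated assumptions (the enclosing `triton::runtime` namespace is assumed from the header's location; the define names and values are made up for the example):

    #include <string>
    #include <vector>

    using namespace triton::runtime;  // assumed enclosing namespace

    options_space_t make_space() {
      options_space_t space;
      space.defines   = { {"TM", {"64", "128"}}, {"TN", {"64", "128"}} };
      space.num_warps = {2, 4, 8};
      return space;
    }

    // Typed access to one autotuned point in the space:
    int tile_m(const options_t &opt) { return opt.D<int>("TM"); }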
@@ -33,25 +33,20 @@ private:
 inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
 {
 //  const driver::device * device = stream->context()->device();
+  size_t warmup = 10;
+  size_t repeat = 50;
   timer tmr;
   std::vector<size_t> times;
   double total_time = 0;
-  op();
+  for(size_t i = 0; i < warmup; i++)
+    op();
   stream->synchronize();
   tmr.start();
-  for(size_t i = 0; i < 10; i++){
-//  while(total_time*1e-9 < 1e-2){
-//    float norm = 1;
-//    normalize clock if possible to reduce noise in auto-tuning
-//    if(normalize)
-//      if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
-//        norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
+  for(size_t i = 0; i < repeat; i++){
     op();
-//    times.push_back(norm*tmr.get().count());
-//    total_time+=times.back();
   }
   stream->synchronize();
-  return (float)tmr.get().count() / 10;
+  return (float)tmr.get().count() / repeat;
 
 //  return *std::min_element(times.begin(), times.end());
 }
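The harness now runs `warmup` untimed iterations before averaging `repeat` timed ones; call sites keep the same shape. An illustrative invocation, where the stream and the enqueue lambda body are assumed from the surrounding test harness:

    // Average kernel time (in timer ticks) over 50 runs after 10 warm-ups.
    double avg = bench([&](){ kernel->enqueue(stream); }, stream);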
@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
     graph_.add_edge({dot, d}, {D, d});
 }
 
-void axes::update_graph_elementwise(ir::instruction *i) {
+void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
   for(unsigned d = 0; d < rank; d++)
   for(ir::value* opx: i->ops())
   for(ir::value* opy: i->ops()){
-    if(!i->get_type()->is_void_ty())
+    if(connect_ret && !i->get_type()->is_void_ty())
       graph_.add_edge({i, d}, {opx, d});
     graph_.add_edge({opx, d}, {opy, d});
   }
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
     case ir::INST_TRANS: return update_graph_trans(i);
     case ir::INST_BROADCAST: return update_graph_broadcast(i);
     case ir::INST_DOT: return update_graph_dot(i);
-    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
+    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
+    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
     case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
     case ir::INST_RECOALESCE: return update_graph_no_edge(i);
     default: return update_graph_elementwise(i);
@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::dot_inst*>(u);
     if(i && is_hmma_c(i) && i->get_operand(n) == v)
-      result = v;
+      result = i;
   }
 }
 
@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
   }
 }
 
-size_t data_layout::find_axis(int to_find) const {
+int data_layout::find_axis(int to_find) const {
   auto it = std::find(axes_.begin(), axes_.end(), to_find);
+  if(it == axes_.end())
+    return -1;
   return std::distance(axes_.begin(), it);
 }
 
@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
  * MMA Layout *
  * -------------------------------- */
 
-mma884_layout::mma884_layout(size_t num_warps,
-                             const std::vector<int>& axes,
-                             const std::vector<unsigned>& shape,
-                             const std::vector<ir::value *> &values,
-                             analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
+mma_layout::mma_layout(size_t num_warps,
+                       const std::vector<int>& axes,
+                       const std::vector<unsigned>& shape,
+                       const std::vector<ir::value *> &values,
+                       analysis::align* align, target* tgt,
+                       shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
   /* fragments per warp */
   // try to make things as square as possible to maximize data re-use
-  fpw_ = {1, 1, 1};
-  std::vector<int> fpw_nm1;
-  unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
-  do {
-    fpw_nm1 = fpw_;
-    if(fpw_[0]*fpw_[1] < num_fragments)
-      fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
-    if(fpw_[0]*fpw_[1] < num_fragments)
-      fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
-  }while(fpw_nm1 != fpw_);
+  if(tgt->as_nvidia()->sm() < 80){
+    fpw_ = {1, 1, 1};
+    std::vector<int> fpw_nm1;
+    unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
+    do {
+      fpw_nm1 = fpw_;
+      if(fpw_[0]*fpw_[1] < num_fragments)
+        fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
+      if(fpw_[0]*fpw_[1] < num_fragments)
+        fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
+    }while(fpw_nm1 != fpw_);
+    auto ord_a = layout_a->get_order();
+    auto ord_b = layout_b->get_order();
+    bool is_a_row = ord_a[0] != 0;
+    bool is_b_row = ord_b[0] != 0;
+    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
+    bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
+    int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
+    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
+    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
+    spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
+  }
+  else{
+    fpw_ = {1, 1, 1};
+    spw_ = {16, 8, 1};
+    rep_ = {2, 2, 1};
+  }
 
   /* warps per tile */
   // try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
   do{
     wpt_nm1 = wpt_;
     if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
-      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
+      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
     if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
-      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
+      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
   }while(wpt_nm1 != wpt_);
 
-  /* sanity check */
-  unsigned effective_num_warps = 1;
-  for(size_t d = 0; d < shape.size(); d++)
-    effective_num_warps *= wpt_[d];
-//  if(num_warps != effective_num_warps)
-//    throw std::runtime_error("cannot create a kernel with this amount of warps");
+  /* shape per block */
+  spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
 }
 
 
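To make the fragments-per-warp loop concrete, a worked example under stated assumptions (row-major A and column-major B, so both pack sizes come out to 1): for a 64x64 output tile on sm < 80, num_fragments = min((64/8)*(64/8), 4) = 4, and fpw_ grows from {1, 1, 1} to {2, 2, 1}, where it stabilizes since 2*2 equals num_fragments; this yields rep_ = {2, 2, 1} and spw_ = {2*8*1, 2*8*1, 1} = {16, 16, 1}. On sm >= 80 the warp-level shape is instead fixed by the mma.16816 instruction: spw_ = {16, 8, 1}.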
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
   ir::value *ptr = nullptr;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
-      if(auto *st = dynamic_cast<ir::store_inst*>(usr))
+      if(auto *st = dynamic_cast<ir::io_inst*>(usr))
         ptr = st->get_pointer_operand();
 
   unsigned i = order_[0];
-  int contiguous = 4;
-  if(ptr)
-    contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
+  int contiguous = 1;
+  if(ptr){
+    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
+    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
+  }
 
   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
   mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     num_threads = num_threads / mts_[i];
   }
-  /* sanity check */
-  unsigned effective_num_threads = 1;
-  for(size_t d = 0; d < shape_.size(); d++)
-    effective_num_threads *= mts_[d];
-
-//  std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
-//  if(num_warps * 32 != effective_num_threads)
-//    throw std::runtime_error("cannot create a kernel with this amount of warps");
 }
 
 
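The 128/nbits cap ties per-thread contiguity to one 128-bit memory transaction instead of the previous hard-coded 4 elements, regardless of element width. Working through the expression above: an fp16 pointer gives nbits = 16, so 128/16 = 8 contiguous elements per thread; fp32 gives 128/32 = 4; fp64 gives 128/64 = 2.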
@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
   ir::value *value_1 = phi->get_incoming_value(1);
   ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
   ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
-  if(!i_0 || !i_1 ||
-     !dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
-     !dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
+  if(!(i_0 && !i_1) &&
+     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
+     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
     return;
   if(is_latch_1)
     res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
 }
 
 
-shared_layout::shared_layout(const data_layout *arg,
+shared_layout::shared_layout(data_layout *arg,
                              const std::vector<int>& axes,
                              const std::vector<unsigned>& shape,
                              const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
                              analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
 
   size_ = 0;
+  arg_layout_ = arg;
 
   // double-buffering
   for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
     extract_hmma_dot_use(v, hmma_dot_a, 0);
     extract_hmma_dot_use(v, hmma_dot_b, 1);
   }
+  hmma_dot_a_ = hmma_dot_a;
+  hmma_dot_b_ = hmma_dot_b;
-  // non-mma ordering
-  std::vector<int> col = {0, 1};
-  std::vector<int> row = {1, 0};
-  for(size_t s = 2; s < get_rank(); s++){
-    col.push_back(s);
-    row.push_back(s);
-  }
-  bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
-  bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
-  if(is_nonhmma_dot_a)
-    order_ = is_trans(dot_a) ? row : col;
-  else if(is_nonhmma_dot_b)
-    order_ = is_trans(dot_b) ? col : row;
-
-  // padding
-  size_t pad = 0;
-  if(hmma_dot_a){
-    bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
-    pad = 24 - shape_[row ? 0 : 1] % 32;
-  }
-  else if(hmma_dot_b){
-    bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
-    pad = 24 - shape_[row ? 1 : 0] % 32;
-  }
-  else if(order_ != arg_order) {
-    pad = 4;
-  }
-  shape_[order_[0]] += pad;
 
   // size
   size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
 }
 
 void layouts::create(size_t id, const std::vector<ir::value*>& values) {
+//  if(layouts_.find(id) != layouts_.end())
+//    return;
   auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
   auto cmp = [](ir::value* x, ir::value *y) {
     std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   const auto& axes = axes_->get(largest);
   const auto& shapes = largest->get_type()->get_tile_shapes();
   auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
-    return dynamic_cast<ir::copy_to_shared_inst*>(v);
+    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
+           dynamic_cast<ir::masked_load_async_inst*>(v);
   });
   // type
-  if(it_hmma_c != values.end())
-    layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
+  if(it_hmma_c != values.end()){
+    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
+    ir::value *a = dot->get_operand(0);
+    ir::value *b = dot->get_operand(1);
+    create(groups_.at(a), values_.at(groups_.at(a)));
+    create(groups_.at(b), values_.at(groups_.at(b)));
+    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
+  }
   else if(it_cts != values.end()){
-    ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
+    ir::instruction *cts = (ir::instruction*)*it_cts;
     ir::value *arg = cts->get_operand(0);
     create(groups_.at(arg), values_.at(groups_.at(arg)));
     layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
   }
-  else
+  else{
     layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
+  }
 }
 
 void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
   }
   if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
     ir::value *val = recoalasce->get_operand(0);
-    mma884_layout* in_layout = get(val)->to_mma884();
+    mma_layout* in_layout = get(val)->to_mma();
     scanline_layout* out_layout = get(i)->to_scanline();
     if(!in_layout || !out_layout)
       return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
     shape[ld] = in_shape[ld];
     for(size_t k = 0; k < in_shape.size(); k++)
       if(k != ld)
-        shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
+        shape[k] = in_layout->to_mma()->spt(k);
     // create layout
     layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
     tmp_[recoalasce] = id;
lib/codegen/analysis/swizzle.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
+#include "triton/codegen/analysis/swizzle.h"
+#include "triton/codegen/analysis/layout.h"
+#include "triton/codegen/target.h"
+#include "triton/ir/type.h"
+#include <iostream>
+
+namespace triton{
+namespace codegen{
+namespace analysis{
+
+
+void swizzle::run(ir::module &) {
+  per_phase_.clear();
+  max_phase_.clear();
+
+  for(auto &x: layouts_->get_all()){
+    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
+    if(!layout)
+      continue;
+    ir::value* mma_dot_a = layout->hmma_dot_a();
+    ir::value* mma_dot_b = layout->hmma_dot_b();
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[layout] = 1;
+      max_phase_[layout] = 1;
+      vec_[layout] = 1;
+      continue;
+    }
+    auto ord = layout->get_order();
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    if(!in_layout)
+      continue;
+    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+    if(tgt_->as_nvidia()->sm() < 80){
+      int inner = mma_dot_a ? 0 : 1;
+      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      if(mma_dot_a)
+        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+      else
+        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+    }
+    else{
+      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = 8 / per_phase_[layout];
+      vec_[layout] = 8;
+    }
+  }
+}
+
+}
+}
+}
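The per_phase/max_phase/vec triple computed above parameterizes the thread swizzling that replaces shared-memory padding. The index transform itself lives in the code generator; the sketch below is my reading of the usual xor-based scheme, not a line from this commit, shown as standalone C++ so the mapping can be inspected:

    #include <cstdio>

    // Sketch (assumed formula): each group of `vec` consecutive elements is
    // XOR-rotated by a phase that advances every `per_phase` rows and wraps
    // after `max_phase` phases, so rows land in distinct banks.
    int swizzled_col(int row, int col, int per_phase, int max_phase, int vec) {
      int phase = (row / per_phase) % max_phase;
      return vec * ((col / vec) ^ phase) + (col % vec);
    }

    int main() {
      // 8 rows of a 16-wide fp16 tile with per_phase=1, max_phase=8, vec=8:
      // the two 8-element vectors of each row map to different locations.
      for (int r = 0; r < 8; r++)
        std::printf("row %d: col 0 -> %2d, col 8 -> %2d\n",
                    r, swizzled_col(r, 0, 1, 8, 8), swizzled_col(r, 8, 1, 8, 8));
    }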
File diff suppressed because it is too large
lib/codegen/selection/machine_layout.cc (deleted file)
@@ -1,325 +0,0 @@
-#include <numeric>
-#include "triton/codegen/selection/machine_layout.h"
-#include "triton/codegen/selection/machine_value.h"
-#include "triton/codegen/selection/generator.h"
-#include "triton/codegen/analysis/allocation.h"
-#include "triton/codegen/analysis/axes.h"
-#include "triton/codegen/target.h"
-#include "triton/ir/instructions.h"
-#include "triton/ir/type.h"
-#include "llvm/IR/IRBuilder.h"
-
-namespace triton{
-namespace codegen{
-
-using namespace llvm;
-
-inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
-  // function
-  if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
-    Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
-    std::vector<Type*> param_tys;
-    std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
-                   [&ctx](ir::type* t){ return llvm_type(t, ctx);});
-    return FunctionType::get(return_ty, param_tys, false);
-  }
-  // pointer
-  if(ty->is_pointer_ty()){
-    Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
-    unsigned addr_space = ty->get_pointer_address_space();
-    return PointerType::get(elt_ty, addr_space);
-  }
-  // integer
-  if(ty->is_integer_ty()){
-    unsigned bitwidth = ty->get_integer_bitwidth();
-    return IntegerType::get(ctx, bitwidth);
-  }
-  // primitive types
-  switch(ty->get_type_id()){
-    case ir::type::VoidTyID: return Type::getVoidTy(ctx);
-    case ir::type::HalfTyID: return Type::getHalfTy(ctx);
-    case ir::type::FloatTyID: return Type::getFloatTy(ctx);
-    case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
-    case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
-    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
-    case ir::type::LabelTyID: return Type::getLabelTy(ctx);
-    case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
-    case ir::type::TokenTyID: return Type::getTokenTy(ctx);
-    default: break;
-  }
-  // unknown type
-  throw std::runtime_error("unknown conversion from ir::type to Type");
-}
-
-// Grid construction
-inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
-  size_t dim = shapes.size();
-  std::vector<Value*> result(dim);
-  for(unsigned k = 0; k < dim - 1; k++){
-    Constant *dim_k = builder.getInt32(shapes[order[k]]);
-    Value *rem = builder.CreateURem(trailing, dim_k);
-    trailing = builder.CreateUDiv(trailing, dim_k);
-    result[order[k]] = rem;
-  }
-  result[order[dim - 1]] = trailing;
-  return result;
-}
-
-inline int32_t ceil(int32_t num, int32_t div){
-  return (num + div - 1)/div;
-}
-
-
-
-machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
-                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
-                                             std::map<ir::value *, Value *>& vmap,
-                                             std::map<ir::value *, tile *>& tmap)
-  : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
-
-  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
-  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
-  // double-buffered
-  if(layout_->get_double_buffer()) {
-    BasicBlock *current = builder_->GetInsertBlock();
-    auto info = *layout_->get_double_buffer();
-    ir::phi_node *phi = info.phi;
-    BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
-    if(parent->empty())
-      builder_->SetInsertPoint(parent);
-    else
-      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
-    // create pointers
-    ptr_ = builder_->CreatePHI(ptr_ty, 2);
-    pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
-    pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
-    offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
-    next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
-    builder_->SetInsertPoint(current);
-  }
-  else{
-    size_t offset = alloc_->offset(layout_);
-    ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
-    ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
-  }
-}
-
-
-tile* machine_shared_layout::create(ir::value *v) {
-  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
-  auto double_buffer = layout_->get_double_buffer();
-  // offset
-  Value *offset = nullptr;
-  if(double_buffer && v == double_buffer->phi)
-    offset = offset_;
-  // base pointer
-  Value *ptr = ptr_;
-  if(double_buffer && v == double_buffer->latch)
-    ptr = next_ptr_;
-  else if(double_buffer && v == double_buffer->first)
-    ptr = pre_ptr_;
-  // create tile
-  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
-}
-
-machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
-                                                       analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
-                                                       analysis::data_layout *layout)
-  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
-
-}
-
-tile *machine_distributed_layout::create(ir::value *v) {
-  Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
-  const auto &shapes = v->get_type()->get_tile_shapes();
-  size_t rank = shapes.size();
-  std::vector<distributed_axis> axes(rank);
-  std::vector<int> order(rank);
-  // compute axes
-  for(size_t d = 0; d < shapes.size(); d++){
-    if(shapes[d] > 1){
-      unsigned x = a_axes_->get(v, d);
-      axes[d] = axes_.at(x);
-    }
-    else{
-      axes[d].contiguous = 1;
-      axes[d].values = {builder_->getInt32(0)};
-    }
-  }
-  // compute order
-  std::iota(order.begin(), order.end(), 0);
-  auto cmp = [&](int x, int y) {
-    unsigned axx = a_axes_->get(v, x);
-    unsigned axy = a_axes_->get(v, y);
-    size_t posx = layout_->find_axis(axx);
-    size_t posy = layout_->find_axis(axy);
-    if(posx < rank && posy < rank)
-      return layout_->get_order(posx) < layout_->get_order(posy);
-    return false;
-  };
-  std::sort(order.begin(), order.end(), cmp);
-  return new distributed_tile(ty, shapes, order, axes, *builder_);
-}
-
-machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
-                                             target *tgt, analysis::axes *a_axes,
-                                             std::map<unsigned, distributed_axis>& axes,
-                                             analysis::mma884_layout* layout)
-  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
-
-  Value *warp_size = builder_->getInt32(32);
-  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
-  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
-  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
-
-  const auto& shape = layout->get_shape();
-  if(shape.size() > 3)
-    throw std::runtime_error("unsupported");
-  bool is_batched = shape.size() >= 3;
-
-  Value *_1 = builder_->getInt32(1);
-  Value *_2 = builder_->getInt32(2);
-  Value *_3 = builder_->getInt32(3);
-  Value *_4 = builder_->getInt32(4);
-  Value *_16 = builder_->getInt32(16);
-
-  // fragments per warp
-  unsigned fpw_0 = layout->fpw(0);
-  unsigned fpw_1 = layout->fpw(1);
-  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
-  // warps per tile
-  unsigned wpt_0 = layout->wpt(0);
-  unsigned wpt_1 = layout->wpt(1);
-  unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
-  // mma warp tile size
-  unsigned hmma_wts_0 = fpw_0 * 8;
-  unsigned hmma_wts_1 = fpw_1 * 8;
-  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
-  // mma block tile size
-  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
-  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
-  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
-  // number of repetition
-  unsigned num_rep_0 = shape[0] / hmma_bts_0;
-  unsigned num_rep_1 = shape[1] / hmma_bts_1;
-  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
-  // size of each pack (interleaving)
-  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
-  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
-  // number of packs (interleaving)
-  num_packs_0_ = num_rep_0 / pack_size_0_;
-  num_packs_1_ = num_rep_1 / pack_size_1_;
-
-  /* intra warp offset */
-  // offset of quad in pair
-  Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
-                                             builder_->getInt32(fpw_0 * pack_size_0_));
-  Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
-                                             builder_->getInt32(fpw_1 * pack_size_1_));
-
-  // Quad pair id
-  Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
-  Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
-  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
-  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
-  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
-  // Quad pair offset
-  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
-  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
-
-  /* inter warp offset */
-  Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
-  Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
-  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
-  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
-  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
-  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
-
-  /* offsets */
-  // a offset
-  offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
-  offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
-  // b offsets
-  offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
-  offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
-
-  // c offsets
-  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
-  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
-                                          builder_->CreateAdd(warp_offset_j, pair_b_off));
-
-  /* indices */
-  // i indices
-  std::vector<Value*> idx_i;
-  for(unsigned pack = 0; pack < num_packs_0_; pack++)
-  for(unsigned ii = 0; ii < pack_size_0_; ii++)
-  for(unsigned i = 0; i < 2; i++){
-    idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
-  }
-  // j indices
-  std::vector<Value*> idx_j;
-  for(unsigned pack = 0; pack < num_packs_1_; pack++)
-  for(unsigned jj = 0; jj < pack_size_1_; jj++)
-  for(unsigned j = 0; j < 2; j++){
-    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
-    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
-  }
-  // z indices
-  std::vector<Value*> idx_z;
-  for(unsigned pack = 0; pack < num_rep_2; pack++)
-    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
-
-
-  /* axes */
-  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
-  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
-  if(is_batched)
-    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
-}
-
-
-machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
-                                                 target *tgt,
-                                                 analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
-                                                 analysis::scanline_layout* layout)
-  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
-
-  Value *warp_size = builder_->getInt32(32);
-  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
-  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
-  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
-
-  auto order = layout->get_order();
-  const auto& shape = layout->get_shape();
-  Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
-  // Delinearize
-  size_t dim = shape.size();
-  std::vector<Value*> thread_id(dim);
-  for(unsigned k = 0; k < dim - 1; k++){
-    Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
-    Value *rem = builder_->CreateURem(full_thread_id, dim_k);
-    full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
-    thread_id[order[k]] = rem;
-  }
-  thread_id[order[dim - 1]] = full_thread_id;
-  // Create axes
-  for(unsigned k = 0; k < dim; k++) {
-    int nts = layout->nts(k);
-    int mts = layout->mts(k);
-    std::string str_k = std::to_string(k);
-    Value *contiguous_k = builder_->getInt32(nts);
-    Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
-    unsigned per_block = nts * mts;
-    unsigned per_thread = nts * shape[k] / per_block;
-    std::vector<Value*> idx_list(per_thread);
-    for(unsigned n = 0 ; n < per_thread; n++){
-      unsigned offset = n / nts * per_block + n % nts;
-      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
-    }
-    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
-  }
-}
-
-
-}
-}
@@ -1,214 +0,0 @@
|
|||||||
#include <numeric>
|
|
||||||
#include <iostream>
|
|
||||||
#include "llvm/IR/IRBuilder.h"
|
|
||||||
#include "triton/codegen/selection/machine_value.h"
|
|
||||||
|
|
||||||
namespace triton{
|
|
||||||
namespace codegen{
|
|
||||||
|
|
||||||
using namespace llvm;
|
|
||||||
|
|
||||||
/* Distributed Tile */
|
|
||||||
void distributed_tile::init_indices() {
|
|
||||||
std::vector<size_t> id(axes_.size(), 0);
|
|
||||||
// build
|
|
||||||
size_t k = 0;
|
|
||||||
while(true) {
|
|
||||||
indices_t current;
|
|
||||||
for(size_t d = 0; d < id.size(); d++)
|
|
||||||
current.push_back(axes_[d].values[id[d]]);
|
|
||||||
size_t sz = indices_.size();
|
|
||||||
indices_[current] = sz;
|
|
||||||
values_[current] = nullptr;
|
|
||||||
ordered_indices_.push_back(current);
|
|
||||||
id[order_[0]]++;
|
|
||||||
while(id[order_[k]] == axes_[order_[k]].values.size()){
|
|
||||||
if(k == id.size() - 1)
|
|
||||||
return;
|
|
||||||
id[order_[k++]] = 0;
|
|
||||||
id[order_[k]]++;
|
|
||||||
}
|
|
||||||
k = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
|
|
||||||
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
|
|
||||||
init_indices();
|
|
||||||
}
|
|
||||||
|
|
||||||
void distributed_tile::set_value(indices_t idx, Value *x) {
|
|
||||||
assert(x->getType() == ty_ && "cannot set a value of different type");
|
|
||||||
Value *&result = values_[idx];
|
|
||||||
assert(!result && "value cannot be set twice");
|
|
||||||
result = x;
|
|
||||||
}
|
|
||||||
|
|
||||||
Value* distributed_tile::get_value(indices_t idx) {
|
|
||||||
Value *result = values_.at(idx);
|
|
||||||
  assert(result && "value has not been set");
  return result;
}

unsigned distributed_tile::get_linear_index(indices_t idx) {
  return indices_[idx];
}

indices_t distributed_tile::get_ordered_indices(unsigned id) {
  return ordered_indices_.at(id);
}

void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
  if(end < 0)
    end = ordered_indices_.size() + end + 1;
  for(unsigned i = start; i < end; i++)
    fn(ordered_indices_[i]);
}

void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
  int rank = sizes.size();
  int len = 1;
  for(int s: sizes)
    len *= s;

  for(int i = 0; i < len; i++){
    indices_t idx(rank);
    int current = i;
    for(int k = 0; k < rank; k++){
      idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
      current = current / sizes[k];
    }
    fn(idx);
  }
}

/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
  BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
  Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
  if(dyn_cast<Constant>(arg)){
    cst = arg;
    non_cst = _0;
    return;
  }
  if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
    non_cst = arg;
    cst = _0;
    return;
  }
  Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
  Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
  if(cst_lhs && cst_rhs){
    cst = arg;
    non_cst = _0;
  }
  else if(cst_lhs){
    cst = cst_lhs;
    non_cst = bin_op->getOperand(1);
  }
  else if(cst_rhs){
    cst = cst_rhs;
    non_cst = bin_op->getOperand(0);
  }
  else{
    non_cst = arg;
    cst = _0;
  }
}

void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
  non_cst_idx.clear();
  cst_idx.clear();
  for(Value *idx: arg_idx){
    Value *non_cst, *cst;
    extract_constant(idx, non_cst, cst);
    non_cst_idx.push_back(non_cst);
    cst_idx.push_back(cst);
  }
}

Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
                                  const std::vector<int>& perm, const std::vector<int>& order,
                                  indices_t idx) {
  // strides
  std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
  strides[order[0]] = builder.getInt32(1);
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
  // result
  Value *result = builder.getInt32(0);
  for(size_t i = 0; i < idx.size(); i++)
    result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
  return result;
}

shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
  tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
  return_vector_ = false;
  if(perm_.empty()){
    perm_.resize(shapes.size());
    std::iota(perm_.begin(), perm_.end(), 0);
  }
}

void shared_tile::set_value(indices_t idx, Value *value) {
  Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
  unsigned addr_space = ptr->getType()->getPointerAddressSpace();
  ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
  builder_.CreateStore(value, ptr);
}

void shared_tile::set_vector_size(unsigned vector_size) {
  vector_size_ = vector_size;
}

void shared_tile::set_return_mode(bool return_vector){
  return_vector_ = return_vector;
}

Value* shared_tile::get_value(indices_t idx) {
  indices_t non_cst_idx, cst_idx;
  extract_constant(idx, non_cst_idx, cst_idx);
  Value *&base_ptr = ptr_cache_[non_cst_idx];
  unsigned vector_size = vector_size_;
  Type *ty = ty_;
  if(ty->isHalfTy() && (vector_size % 2 == 0)){
    ty = IntegerType::get(ty->getContext(), 32);
    vector_size = vector_size / 2;
  }
  if(base_ptr == nullptr){
    // BasicBlock* store = builder_.GetInsertBlock();
    // if(!non_cst_idx.empty())
    //   if(isa<Instruction>(non_cst_idx.front())){
    //     builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
    //   }
    base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
    if(vector_size_ > 1){
      Type *vec_ty = VectorType::get(ty, vector_size);
      Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
      base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
    }
    // builder_.SetInsertPoint(store);
  }
  Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
  Value *div = offset;
  if(vector_size_ > 1)
    div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
  Value *ptr = builder_.CreateGEP(base_ptr, div);
  Value *result = builder_.CreateLoad(ptr);
  if(return_vector_ == false && vector_size_ > 1) {
    Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
    result = builder_.CreateExtractElement(result, rem);
  }
  return result;
}

}
}
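shared_tile::shared_offset above linearizes a multi-dimensional index: strides are laid out along `order` (with `order[0]` the fastest-moving axis), and the index is read through the permutation `perm`. A minimal scalar sketch of the same arithmetic, assuming plain integer shapes (a standalone illustration, not part of the commit):

    #include <cstddef>
    #include <vector>

    // Same stride construction as shared_tile::shared_offset, on plain ints:
    // strides follow `order`, and the i-th partial sum multiplies idx[perm[i]].
    int linear_offset(const std::vector<int>& shapes,
                      const std::vector<int>& perm,
                      const std::vector<int>& order,
                      const std::vector<int>& idx) {
      std::vector<int> strides(shapes.size(), 0);
      strides[order[0]] = 1;
      for(size_t i = 1; i < idx.size(); i++)
        strides[order[i]] = strides[order[i-1]] * shapes[order[i-1]];
      int result = 0;
      for(size_t i = 0; i < idx.size(); i++)
        result += idx[perm[i]] * strides[i];
      return result;
    }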
@@ -14,6 +14,12 @@ namespace triton{
 namespace codegen{

 // base

+nvidia_cu_target* target::as_nvidia() {
+  return dynamic_cast<nvidia_cu_target*>(this);
+}
+
 bool target::is_gpu() const {
   return is_gpu_;
 }
@@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m

 Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
   Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
-  return builder.CreateCall(barrier, {});
+  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
 }

 Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
     Intrinsic::amdgcn_workgroup_id_y,
     Intrinsic::amdgcn_workgroup_id_z
   };
-  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
-  Value* group_id = builder.CreateCall(get_group_id, {});
+  Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
   return group_id;
 }

@@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
     Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
     Intrinsic::nvvm_read_ptx_sreg_ctaid_z
   };
-  Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
-  Value* cta_id = builder.CreateCall(get_cta_id, {});
+  Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
   return cta_id;
 }

@@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
     Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
     Intrinsic::nvvm_read_ptx_sreg_nctaid_z
   };
-  Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
-  return builder.CreateCall(get_nctaid, {});
+  return builder.CreateIntrinsic(ids[ax], {}, {});
 }

 // CPU
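The four hunks above all make the same substitution: the two-step Intrinsic::getDeclaration + CreateCall pattern becomes a single IRBuilder::CreateIntrinsic call, which declares the intrinsic in the current module and emits the call in one go. A self-contained sketch of the new form (the module and function names are illustrative):

    #include "llvm/IR/Function.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Intrinsics.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/raw_ostream.h"

    using namespace llvm;

    int main() {
      LLVMContext ctx;
      Module mod("demo", ctx);
      IRBuilder<> builder(ctx);
      // Build i32 @demo_fn() { ret i32 <ctaid.x> }
      FunctionType *fty = FunctionType::get(builder.getInt32Ty(), false);
      Function *fn = Function::Create(fty, Function::ExternalLinkage, "demo_fn", &mod);
      builder.SetInsertPoint(BasicBlock::Create(ctx, "entry", fn));
      // One step: declares the intrinsic and emits the call.
      Value *ctaid = builder.CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {});
      builder.CreateRet(ctaid);
      mod.print(outs(), nullptr);
    }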
@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {

 for(size_t id = 0; id < num_groups; id++) {
-  if(!layout_->get(id)->to_mma884())
+  if(!layout_->get(id)->to_mma())
     continue;
   // extract memory stores
   const auto& values = layout_->values_of(id);
@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
     return true;
   if(i->get_id() == ir::INST_COPY_TO_SHARED)
     return true;
+  if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
+    return true;
   return false;
 }

 // run pass on module
-void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
+void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
   auto *i = dynamic_cast<ir::instruction*>(x);
   // not an instruction
   if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
   // copy
   builder.set_insert_point_after(i);
   ir::value *copy;
-  if(to_shared)
+  if(to_shared){
     copy = builder.create_copy_to_shared(x);
+  }
   else
     copy = builder.create_copy_from_shared(x);
   parent->replace_uses_of_with(x, copy);
@@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
   add_reference(i, res);
 }

-void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
+void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
   if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
     std::set<ir::value*> incoming;
     for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
       if(incoming.insert(inc_val).second){
         ir::basic_block *block = inc_val->get_parent();
         builder.set_insert_point(block->get_inst_list().back());
-        builder.create_barrier();
+        if(type.first)
+          builder.create_async_wait();
+        if(type.second)
+          builder.create_barrier();
       }
     }
   }
@@ -85,8 +88,9 @@ std::pair<membar::interval_vec_t,
           membar::interval_vec_t> membar::transfer(ir::basic_block *block,
                                                    const interval_vec_t &written_to,
                                                    const interval_vec_t &read_from,
-                                                   std::set<ir::instruction*>& insert_loc,
-                                                   std::set<ir::value*>& safe_war) {
+                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
+                                                   std::set<ir::value*>& safe_war,
+                                                   std::vector<ir::instruction*>& to_sync) {
   ir::basic_block::inst_list_t instructions = block->get_inst_list();
   interval_vec_t new_written_to = written_to;
   interval_vec_t new_read_from = read_from;
@@ -95,6 +99,8 @@ std::pair<membar::interval_vec_t,
     interval_vec_t read, written;
     get_read_intervals(i, read);
     get_written_intervals(i, written);
+    if(written.size())
+      to_sync.push_back(i);
     bool read_after_write = intersect(new_written_to, read);
     bool write_after_read = intersect(new_read_from, written);
     // double buffering
@@ -104,9 +110,14 @@ std::pair<membar::interval_vec_t,
     }
     // record hazards
     if(read_after_write || write_after_read) {
-      insert_loc.insert(i);
+      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
+      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
+      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
+      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
+      insert_loc.insert({i, {copy_async_wait, barrier}});
       new_written_to.clear();
       new_read_from.clear();
+      to_sync.clear();
     }
     std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
     std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
@@ -125,17 +136,17 @@ void membar::run(ir::module &mod) {
     if(!layout || !layout->get_double_buffer())
       continue;
     for(ir::value *v: layout->get_values())
-      if(v != layout->get_double_buffer()->phi)
+      if(v != layout->get_double_buffer()->phi){
         safe_war.insert(v);
+      }
   }

   for(ir::function *fn: mod.get_function_list()){
     std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
     std::map<ir::basic_block*, interval_vec_t> written_to;
     std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::set<ir::instruction*> insert_locs;
+    std::vector<ir::instruction*> to_sync;
+    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
     size_t n_inserted_im1 = 0;
     bool done = false;
     do{
@@ -150,7 +161,7 @@ void membar::run(ir::module &mod) {
       for(ir::basic_block* pred: block->get_predecessors())
         pred_read_from.push_back(read_from[pred]);
       // apply transfer function
-      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
+      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
       written_to[block] = result.first;
       read_from[block] = result.second;
     }
@@ -158,8 +169,9 @@ void membar::run(ir::module &mod) {
       done = (n_inserted_im1 == n_inserted_i);
       n_inserted_im1 = n_inserted_i;
     }while(!done);
-    for(ir::instruction* i: insert_locs)
-      insert_barrier(i, builder);
+    for(auto x: insert_locs){
+      insert_barrier(x.first, x.second, builder);
+    }
   }
 }
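With this change, a hazard no longer unconditionally becomes a barrier: the transfer function remembers which instructions wrote shared memory since the last synchronization (to_sync) and derives a pair of flags, an async wait if any of them was a masked_load_async, a barrier if any was a copy_to_shared. A standalone sketch of that classification (the enum stands in for the actual IR instruction kinds):

    #include <algorithm>
    #include <utility>
    #include <vector>

    enum class write_kind { load_async, copy_to_shared };

    // Mirror of the hazard classification in membar::transfer: scan the
    // shared-memory writes pending since the last fence and decide which
    // synchronizations to insert at the hazard point.
    std::pair<bool, bool> classify(const std::vector<write_kind>& to_sync) {
      bool need_async_wait = std::any_of(to_sync.begin(), to_sync.end(),
          [](write_kind k){ return k == write_kind::load_async; });
      bool need_barrier = std::any_of(to_sync.begin(), to_sync.end(),
          [](write_kind k){ return k == write_kind::copy_to_shared; });
      return {need_async_wait, need_barrier};
    }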
@@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){

 //}

+bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
+  auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
+  if(!copy_to_shared)
+    return false;
+  ir::value *arg = copy_to_shared->get_operand(0);
+  ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
+  if(!ld)
+    return false;
+  builder.set_insert_point(copy_to_shared);
+  ir::value *ptr = ld->get_pointer_operand();
+  ir::value *msk = ld->get_mask_operand();
+  ir::value *val = ld->get_false_value_operand();
+  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+  copy_to_shared->replace_all_uses_with(new_load);
+  return true;
+}
+
 bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
   auto x = dynamic_cast<ir::reduce_inst*>(value);
   if(!x)
@@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) {
       continue;
     bool was_modified = false;
     was_modified = was_modified || rewrite_mult(i, builder);
     // was_modified = was_modified || rewrite_cts_cfs(i, builder);
     was_modified = was_modified || rewrite_trans_phi(i, builder);
     was_modified = was_modified || rewrite_unit_red(i, builder);
     was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
+    // if(tgt_->as_nvidia()->sm() >= 80)
+    //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
     if(was_modified)
       seen.insert(i);
   }
lib/codegen/transform/reorder.cc (new file, 51 lines)
@@ -0,0 +1,51 @@
+#include <iostream>
+#include <algorithm>
+#include "triton/ir/module.h"
+#include "triton/ir/function.h"
+#include "triton/ir/basic_block.h"
+#include "triton/ir/instructions.h"
+#include "triton/codegen/transform/reorder.h"
+
+namespace triton {
+namespace codegen{
+namespace transform{
+
+void reorder::run(ir::module& mod){
+  ir::builder &builder = mod.get_builder();
+  std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
+
+  for(ir::function *fn: mod.get_function_list())
+  for(ir::basic_block *block: fn->blocks())
+  for(ir::instruction* i: block->get_inst_list()){
+    if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
+      ir::value* _ptr = ld->get_pointer_operand();
+      ir::value* _msk = ld->get_mask_operand();
+      ir::value* _val = ld->get_false_value_operand();
+      auto ptr = std::find(block->begin(), block->end(), _ptr);
+      auto msk = std::find(block->begin(), block->end(), _msk);
+      auto val = std::find(block->begin(), block->end(), _val);
+      if(ptr == block->end() || msk == block->end() || val == block->end())
+        continue;
+      auto it = std::find(block->begin(), block->end(), i);
+      int dist_ptr = std::distance(ptr, it);
+      int dist_msk = std::distance(msk, it);
+      int dist_val = std::distance(val, it);
+      if(dist_ptr < dist_msk && dist_ptr < dist_val)
+        builder.set_insert_point(++ptr);
+      if(dist_msk < dist_ptr && dist_msk < dist_val)
+        builder.set_insert_point(++msk);
+      if(dist_val < dist_ptr && dist_val < dist_msk)
+        builder.set_insert_point(++val);
+      ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
+      to_replace.push_back(std::make_pair(ld, new_ld));
+    }
+  }
+
+  for(auto& x: to_replace)
+    x.first->replace_all_uses_with(x.second);
+}
+
+}
+}
+}
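reorder::run sinks each masked load so that it is re-created immediately after whichever of its three operands is defined last, i.e. the operand whose definition is closest to the load. A toy version of that choice over plain instruction positions (positions are illustrative):

    #include <algorithm>
    #include <cassert>

    // The operand defined last has the largest position (equivalently, the
    // smallest distance to the load); the load is re-emitted right after it.
    int insertion_point(int pos_ptr, int pos_msk, int pos_val) {
      return std::max({pos_ptr, pos_msk, pos_val}) + 1;
    }

    int main() {
      // pointer defined at 3, mask at 7, false-value at 5 -> re-emit at 8
      assert(insertion_point(3, 7, 5) == 8);
    }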
@@ -48,46 +48,6 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
 // CUDA //
 /* ------------------------ */

-// architecture
-cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
-  switch(sm.first) {
-    case 7:
-      switch(sm.second){
-        case 0: return Architecture::SM_7_0;
-      }
-
-    case 6:
-      switch(sm.second){
-        case 0: return Architecture::SM_6_0;
-        case 1: return Architecture::SM_6_1;
-      }
-
-    case 5:
-      switch(sm.second){
-        case 0: return Architecture::SM_5_0;
-        case 2: return Architecture::SM_5_2;
-        default: return Architecture::UNKNOWN;
-      }
-
-    case 3:
-      switch(sm.second){
-        case 0: return Architecture::SM_3_0;
-        case 5: return Architecture::SM_3_5;
-        case 7: return Architecture::SM_3_7;
-        default: return Architecture::UNKNOWN;
-      }
-
-    case 2:
-      switch(sm.second){
-        case 0: return Architecture::SM_2_0;
-        case 1: return Architecture::SM_2_1;
-        default: return Architecture::UNKNOWN;
-      }
-
-    default: return Architecture::UNKNOWN;
-  }
-}
-
 // information query
 template<CUdevice_attribute attr>
 int cu_device::cuGetInfo() const{
@@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
   return map.at(key);
 }

-// architecture
-cu_device::Architecture cu_device::architecture() const{
-  return nv_arch(compute_capability());
-}
-
 // number of address bits
 size_t cu_device::address_bits() const{
   return sizeof(size_t)*8;
@@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
 }

 // force the device to be interpreted as a particular cc
-void cu_device::interpret_as(std::pair<size_t, size_t> cc){
-  interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
+void cu_device::interpret_as(int cc){
+  interpreted_as_ = std::make_shared<int>(cc);
 }

 // compute capability
-std::pair<size_t, size_t> cu_device::compute_capability() const {
+int cu_device::compute_capability() const {
   if(interpreted_as_)
     return *interpreted_as_;
-  size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
-  size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
-  return std::make_pair(_major, _minor);
+  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
+  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
+  return major*10 + minor;
 }

 // maximum number of threads per block
@@ -218,7 +173,7 @@ std::string cu_device::infos() const{

 // target
 std::unique_ptr<codegen::target> cu_device::make_target() const {
-  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
+  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
 }
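compute_capability now returns a single integer encoded as major*10 + minor, which is the value make_bin later compares against 80 to enable the async-copy variant of the cts pass. A tiny sketch of the encoding:

    #include <cassert>

    // Encode a (major, minor) compute capability the way cu_device now does.
    int encode_cc(int major, int minor) { return major * 10 + minor; }

    int main() {
      assert(encode_cc(7, 0) == 70);  // V100
      assert(encode_cc(8, 0) == 80);  // A100: sm() >= 80 gates async copies
    }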
@@ -93,6 +93,7 @@ namespace driver

 bool dispatch::cuinit(){
   if(cuda_==nullptr){
+    putenv((char*)"CUDA_CACHE_DISABLE=1");
     std::string libcuda = tools::getenv("TRITON_LIBCUDA");
     if(libcuda.empty())
       cuda_ = dlopen("libcuda.so", RTLD_LAZY);
@@ -20,7 +20,9 @@
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
 #include <fstream>
+#include <unistd.h>
 #include <memory>
+#include <regex>
 #include "triton/driver/module.h"
 #include "triton/driver/context.h"
 #include "triton/driver/error.h"
@@ -41,6 +43,19 @@
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/Transforms/Utils/Cloning.h"

+std::string exec(const char* cmd) {
+  std::array<char, 128> buffer;
+  std::string result;
+  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
+  if (!pipe) {
+    throw std::runtime_error("popen() failed!");
+  }
+  while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
+    result += buffer.data();
+  }
+  return result;
+}
+
 namespace triton
 {
 namespace driver
@@ -63,11 +78,11 @@ void module::init_llvm() {
 }

 module::module(CUmodule mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership) {
+  : polymorphic_resource(mod, has_ownership), spilled_(0) {
 }

 module::module(host_module_t mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership) {
+  : polymorphic_resource(mod, has_ownership), spilled_(0) {
 }

@@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
                                  file_type_t ft) {
   init_llvm();
   // // debug
-  // llvm::legacy::PassManager pm;
+  llvm::legacy::PassManager pm;
+  std::string tmp;
+  // llvm::raw_string_ostream oss(llir_);
   // pm.add(llvm::createPrintModulePass(llvm::outs()));
-  // pm.add(llvm::createVerifierPass());
-  // pm.run(*module);
+  pm.add(llvm::createVerifierPass());
+  pm.run(*module);
   // create machine
   module->setTargetTriple(triple);
   std::string error;
@@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_

   // create execution engine
   for(llvm::Function& fn: src->functions())
-    hst_->functions[fn.getName()] = &fn;
+    hst_->functions[fn.getName().str()] = &fn;

   // llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
   // auto DL = JTMB.getDefaultDataLayoutForTarget();
@@ -225,7 +242,8 @@ static std::map<int, int> vptx = {
   {10010, 64},
   {10020, 65},
   {11000, 70},
-  {11010, 71}
+  {11010, 71},
+  {11020, 72}
 };

 std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
@@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
   assert(short_ptr);
   short_ptr->setValue(true);
   // compute capability
-  auto _cc = ((driver::cu_device*)device)->compute_capability();
-  int cc = _cc.first*10 + _cc.second;
-  cc = std::min(cc, max_nvvm_cc);
+  int cc = ((driver::cu_device*)device)->compute_capability();
   std::string sm = "sm_" + std::to_string(cc);
   // driver version
   int version;
@@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
     throw std::runtime_error("Triton requires CUDA 10+");
   // PTX version
   int ptx = vptx.at(version);
-  ptx = std::min(ptx, max_nvvm_ptx);
   int ptx_major = ptx / 10;
   int ptx_minor = ptx % 10;
   // create
   llvm::SmallVector<char, 0> buffer;
-  module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly);
+  module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly);
   std::string result(buffer.begin(), buffer.end());
   find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
   find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
@@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 }

-cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }
+cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { }

-cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
+cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){
   // JIT compile source-code
-  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
-  unsigned int errbufsize = 8096;
-  std::string errbuf(errbufsize, 0);
-  void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
   try{
-    dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
-  }catch(exception::cuda::invalid_ptx const &){
+    // // compile ptx with ptxas
+    // char _fsrc[] = "/tmp/triton_k_XXXXXX";
+    // char _flog[] = "/tmp/triton_l_XXXXXX";
+    // int fdsrc = mkstemp(_fsrc);
+    // int fdlog = mkstemp(_flog);
+    // std::string fsrc = _fsrc;
+    // std::string flog = _flog;
+    // std::ofstream ofs(fsrc);
+    // ofs << source;
+    // ofs.close();
+    // std::string cmd;
+    // int err;
+    // driver::cu_device* cu_device = (driver::cu_device*)device;
+    // cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+    // err = system(cmd.c_str());
+    // dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
+    // std::ifstream file(flog);
+    // std::string log;
+    // if(file)
+    //   while (!file.eof()) log.push_back(file.get());
+    // unlink(_fsrc);
+    // unlink(_flog);
+
+    // std::smatch match;
+    // std::regex expr ("\\b([0-9]+) bytes spill");
+    // spilled_ = 0;
+    // while (std::regex_search (log,match,expr)){
+    //   spilled_ += std::stoi(match[1]);
+    //   log = match.suffix();
+    // }
+    // std::cout << log << std::endl;
+
+    CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
+                          CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
+                          CU_JIT_LOG_VERBOSE};
+    unsigned int errbufsize = 8192;
+    unsigned int logbufsize = 8192;
+    char _err[errbufsize];
+    char _log[logbufsize];
+    void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
+    dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
+    std::string err(_err);
+    std::string log(_log);
+
+    // std::cout << log << std::endl;
+    std::smatch match;
+    std::regex expr ("\\b([0-9]+) bytes spill");
+    spilled_ = 0;
+    while (std::regex_search(log,match,expr)){
+      spilled_ += std::stoi(match[1]);
+      log = match.suffix();
+    }
+  }
+  catch(exception::cuda::invalid_ptx const &){
 //#ifdef TRITON_LOG_PTX_ERROR
     std::cout << source << std::endl;
     std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
-    std::cerr << errbuf << std::endl;
     // exit(1);
 //#endif
     throw;
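Loading the PTX with a verbose JIT log is what makes register spills observable: the constructor sums every "N bytes spill" occurrence that ptxas reports. A self-contained sketch of the same extraction on a canned log (the log text is illustrative):

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
      std::string log = "ptxas info: 128 bytes spill stores, 96 bytes spill loads";
      std::regex expr("\\b([0-9]+) bytes spill");
      std::smatch match;
      int spilled = 0;
      // Same loop as in cu_module::cu_module: consume one match at a time.
      while (std::regex_search(log, match, expr)) {
        spilled += std::stoi(match[1]);
        log = match.suffix();
      }
      std::cout << spilled << std::endl;  // prints 224
    }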
@@ -1,5 +1,6 @@
 #include <string>
 #include <algorithm>
+#include <iostream>
 #include "triton/ir/basic_block.h"
 #include "triton/ir/builder.h"
 #include "triton/ir/constant.h"
@@ -253,6 +254,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)

 value *builder::create_load(value *ptr, const std::string &name){
   return insert(unmasked_load_inst::create(ptr, name));
+  // type *ty = ptr->get_type()->get_pointer_element_ty();
+  // value *mask = constant_int::get(get_int1_ty(), 1);
+  // value *undef = undef_value::get(ty);
+  // if(ptr->get_type()->is_tile_ty()){
+  //   auto shapes = ptr->get_type()->get_tile_shapes();
+  //   return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
+  // }
+  // return insert(masked_load_inst::create(ptr, mask, undef, name));
 }

 value *builder::create_store(value *ptr, value *val, const std::string &name){
@@ -263,6 +273,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value,
   return insert(masked_load_inst::create(ptr, mask, false_value, name));
 }

+
 value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
   return insert(masked_store_inst::create(ptr, val, mask, name));
 }
@@ -348,13 +359,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
   return insert(copy_to_shared_inst::create(arg, name));
 }

 value *builder::create_copy_from_shared(value *arg, const std::string &name) {
   return insert(copy_from_shared_inst::create(arg, name));
 }

+value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
+  return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
+}
+
 value *builder::create_barrier(const std::string &name) {
   return insert(barrier_inst::create(ctx_, name));
 }

+value *builder::create_async_wait() {
+  return insert(async_wait_inst::create(ctx_));
+}
+
 }
 }
@@ -463,6 +463,20 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
   return new masked_load_inst(ptr, mask, false_value, name, next);
 }

+// masked load async
+masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                                               const std::string &name, instruction *next)
+  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
+  set_operand(0, ptr);
+  set_operand(1, mask);
+  set_operand(2, false_value);
+}
+
+masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
+                                                       const std::string &name, instruction *next) {
+  return new masked_load_async_inst(ptr, mask, false_value, name, next);
+}
+
 // atomic add

 atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
@@ -804,6 +818,14 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
   return new barrier_inst(ctx, name, next);
 }

+async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
+                                 instruction *next)
+  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
+
+async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
+  return new async_wait_inst(ctx, name, next);
+}
+
 // nv_dynamic_program_idx
 make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
@@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) {
       os << get_name(ops[i], cnt++);
       os << (i < num_ops - 1?", ":"");
     }
-    os << ";" << std::endl;
+    os << ";";
+    // os << " (";
+    // for(ir::user* usr: inst->get_users())
+    //   os << get_name(usr, cnt++) << ", " ;
+    // os << " )";
+    os << std::endl;
   }
 }
 os << "}" << std::endl;
@@ -68,9 +68,10 @@ unsigned user::get_num_hidden() const {

 value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
   for(size_t i = 0; i < ops_.size(); i++)
-    if(ops_[i] == before)
+    if(ops_[i] == before){
       ops_[i] = after;
       after->add_use(this);
+    }
   return before->erase_use(this);
 }
@@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
     return set_ret(bld_->create_dot(lhs, rhs, _0));
   }
   case Token::MASKED_DEREF: {
+    // TODO: FIXME
     ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
     ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty());
+    auto it = bld_->get_insert_block();
     if(ret_ty->is_tile_ty())
       false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes());
+    bld_->set_insert_point(it);
     return set_ret(bld_->create_masked_load(rhs, lhs, false_value));
   }
   case Token::ELLIPSIS: {
@@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
   if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
     if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
       false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
-    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
-                                                 cond,
-                                                 false_val);
+    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
     ld->replace_all_uses_with(new_ld);
     ld->erase_from_parent();
     return set_ret(new_ld);
@@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
   });
   if(init_)
     VisitStmt(init_);
-  // VisitExpr(cond_);
-  // ir::value *cond = ret_;
-  // bld_->create_cond_br(cond, loop_bb, next_bb);
-  bld_->create_br(loop_bb);
+  VisitExpr(cond_);
+  ir::value *cond = ret_;
+  bld_->create_cond_br(cond, loop_bb, next_bb);
+  // bld_->create_br(loop_bb);
   bld_->set_insert_point(loop_bb);
   if(body_)
     VisitStmt(body_);
@@ -1,4 +1,4 @@
 #include <string>
 #include <mutex>
 #include <regex>
 #include <functional>
@@ -9,11 +9,13 @@
 #include "triton/codegen/analysis/allocation.h"
 #include "triton/codegen/analysis/liveness.h"
 #include "triton/codegen/analysis/align.h"
+#include "triton/codegen/analysis/swizzle.h"
 #include "triton/codegen/transform/coalesce.h"
 #include "triton/codegen/transform/dce.h"
 #include "triton/codegen/transform/peephole.h"
 #include "triton/codegen/transform/membar.h"
 #include "triton/codegen/transform/reassociate.h"
+#include "triton/codegen/transform/reorder.h"
 #include "triton/codegen/transform/cts.h"
 #include "triton/codegen/transform/disassociate.h"
 #include "triton/codegen/selection/generator.h"
@@ -29,6 +31,7 @@
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/ir/print.h"
+#include "triton/runtime/error.h"
 #include "triton/tools/bench.hpp"
 #include "triton/tools/sha1.hpp"
 #include "triton/tools/sys/getenv.hpp"
@@ -67,7 +70,7 @@ void _loop_nest(std::vector<size_t> const & ranges,
 /* OPTIONS */
 /* --------------------- */

-std::string function::options_t::to_str() const{
+std::string options_t::to_str() const{
   std::string ret = "nw-" + std::to_string(num_warps);
   for(const auto& x : defines){
     ret += '-';
@@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) {
   throw std::runtime_error("unknown type");
 }

-void function::caller::write(std::ofstream &ofs) {
-  // write name
-  ofs << name_ << std::endl;
-  // write signature
-  for(size_t i = 0; i < param_tys_.size(); i++)
-    ofs << param_tys_[i] << " ";
-  ofs << std::endl;
-  // write module
-  std::string source = ((driver::cu_module*)(&*parent_))->source();
-  ofs << source;
-}
+//void function::caller::write(std::ofstream &ofs) {
+//  // write name
+//  ofs << name_ << std::endl;
+//  // write signature
+//  for(size_t i = 0; i < param_tys_.size(); i++)
+//    ofs << param_tys_[i] << " ";
+//  ofs << std::endl;
+//  // write module
+//  std::string source = ((driver::cu_module*)(&*parent_))->ptx();
+//  ofs << source;
+//}

-void function::caller::read(std::ifstream &ifs) {
-  // read name
-  std::getline(ifs, name_);
-  // read signature
-  std::string line;
-  std::getline(ifs, line);
-  std::istringstream current(line);
-  int param;
-  param_tys_.clear();
-  while(current >> param)
-    param_tys_.push_back((arg_type)param);
-  // read module
-  std::string src((std::istreambuf_iterator<char>(ifs)),
-                   std::istreambuf_iterator<char>());
-  parent_.reset(new driver::cu_module(src));
-  bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
-}
+//void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
+//  // read name
+//  std::getline(ifs, name_);
+//  // read signature
+//  std::string line;
+//  std::getline(ifs, line);
+//  std::istringstream current(line);
+//  int param;
+//  param_tys_.clear();
+//  while(current >> param)
+//    param_tys_.push_back((arg_type)param);
+//  // read module
+//  std::string src((std::istreambuf_iterator<char>(ifs)),
+//                   std::istreambuf_iterator<char>());
+//  parent_.reset(new driver::cu_module(ctx, src));
+//  bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
+//}

-function::caller::caller(std::ifstream &ifs, const options_t& opt)
-  : opt_(opt) {
-  read(ifs);
-}
+//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
+//  : opt_(opt) {
+//  read(ctx, ifs);
+//}

 function::caller::caller(ir::function *ir,
                          std::shared_ptr<driver::module> parent, const options_t& opt)
|
|||||||
// generate llvm code
|
// generate llvm code
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||||
|
// optimizations
|
||||||
|
bool cts_use_async = target->as_nvidia()->sm() >= 80;
|
||||||
// create passes
|
// create passes
|
||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
|
codegen::transform::cts cts(cts_use_async);
|
||||||
codegen::transform::disassociate disassociate;
|
codegen::transform::disassociate disassociate;
|
||||||
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
||||||
codegen::analysis::liveness liveness(&layouts);
|
codegen::analysis::liveness liveness(&layouts);
|
||||||
|
codegen::analysis::swizzle swizzle(&layouts, target.get());
|
||||||
codegen::analysis::allocation allocation(&liveness);
|
codegen::analysis::allocation allocation(&liveness);
|
||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole;
|
codegen::transform::peephole peephole(target.get());
|
||||||
codegen::transform::reassociate reassociate;
|
codegen::transform::reassociate reassociate;
|
||||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||||
codegen::transform::cts cts;
|
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
|
||||||
// run passes
|
// run passes
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
disassociate.run(module);
|
disassociate.run(module);
|
||||||
@@ -233,17 +239,20 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::d
|
|||||||
}
|
}
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
// ir::print(module, std::cout);
|
|
||||||
align.run(module);
|
align.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
|
swizzle.run(module);
|
||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run(module);
|
allocation.run(module);
|
||||||
if(allocation.allocated_size() > device->max_shared_memory())
|
if(allocation.allocated_size() > device->max_shared_memory())
|
||||||
throw std::runtime_error("using too much shared memory");
|
throw exception::out_of_shared_memory();
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
|
||||||
|
if(res->spilled() > 256)
|
||||||
|
throw exception::out_of_registers();
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) {
|
|||||||
auto ir = make_ir(parser);
|
auto ir = make_ir(parser);
|
||||||
// triton-ir -> binary
|
// triton-ir -> binary
|
||||||
std::unique_ptr<driver::module> bin;
|
std::unique_ptr<driver::module> bin;
|
||||||
// try{
|
try{
|
||||||
bin = make_bin(*ir, device, opt);
|
bin = make_bin(*ir, device, opt);
|
||||||
// }catch(const std::runtime_error&){
|
}catch(const exception::base&){
|
||||||
// return nullptr;
|
throw;
|
||||||
// }
|
}
|
||||||
// create callable
|
// create callable
|
||||||
ir::function *tmp = ir->get_function_list()[0];
|
ir::function *tmp = ir->get_function_list()[0];
|
||||||
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
|
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
|
||||||
@@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space)
   for(const auto& x: space.defines)
     ranges.push_back(x.second.size());
   // functor for source with given option
+  std::map<options_t, std::string> err;
   auto do_make = [&](std::vector<size_t> params) {
     // compilation options
     unsigned i = 0;
@@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space)
     for(auto D: space.defines)
       opt.defines[D.first] = D.second[params[i++]];
     // compile
-    make(device, opt);
+    try{
+      make(device, opt);
+    }catch(const exception::base& e){
+      err[opt] = e.what();
+    }
   };
   // multi-threaded compilation
   _loop_nest(ranges, do_make);
-  if(callers_.empty())
-    throw std::runtime_error("could not compile kernel");
+  if(callers_.empty()){
+    std::ostringstream dbg;
+    dbg << "Auto-Tuner could not find any valid configuration:" << std::endl;
+    for(auto x: err){
+      dbg << "[ ";
+      dbg << x.first.num_warps << ", ";
+      dbg << "{ ";
+      for(const auto& y: x.first.defines)
+        dbg << '"' << y.first << "\"= \"" << y.second << "\", ";
+      dbg << " } ] -> " << x.second << std::endl;
+    }
+    throw exception::no_valid_configuration(dbg.str());
+  }
 }

-std::string function::ptx(driver::device* device, const options_t& opt) {
+std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) {
   make(device, opt);
   const auto& fn = callers_.at(opt);
   if(!fn)
     return "";
-  return ((driver::cu_module*)fn->parent())->source();
+  switch(mode){
+    case ASM_LLIR:{
+      return fn->parent()->llir();
+    }
+    case ASM_NV_PTX:
+    case ASM_NV_SASS:{
+      std::string ptx = ((driver::cu_module*)fn->parent())->ptx();
+      // SASS
+      std::string input = std::tmpnam(nullptr);
+      std::string output = std::tmpnam(nullptr);
+      std::ofstream ofs(input);
+      ofs << ptx;
+      ofs.close();
+      if(mode == ASM_NV_PTX)
+        return ptx;
+      std::string cmd;
+      int err;
+      // compile ptx
+      driver::cu_device* cu_device = (driver::cu_device*)device;
+      cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o";
+      err = system(cmd.c_str());
+      // disassemble
+      cmd = "cuobjdump --dump-sass " + input + ".o >> " + output;
+      err = system(cmd.c_str());
+      std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
+      std::string to_delete = " /*";
+      std::ifstream ifs(output);
+      std::string line;
+      std::string sass;
+      while(std::getline(ifs, line))
+        if(!std::regex_match(line, comment))
+          sass += line + "\n";
+      return sass;
+    }
+    default:
+      return "";
+  }
 }

 // returns program with best compilation options for given parameter
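The ASM_NV_SASS path post-processes the cuobjdump output: lines that consist solely of the raw hex-encoding comment are dropped, instruction lines are kept. A standalone sketch of that filter (the dump text is illustrative):

    #include <iostream>
    #include <regex>
    #include <sstream>
    #include <string>

    int main() {
      std::string dump =
          "        MOV R1, c[0x0][0x28] ;\n"
          "        /* 0x00000a0000017a02 */\n";
      // Same pattern as function::get_asm: a line is removed only when the
      // whole line matches the encoding comment.
      std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/");
      std::istringstream ifs(dump);
      std::string line, sass;
      while (std::getline(ifs, line))
        if (!std::regex_match(line, comment))
          sass += line + "\n";
      std::cout << sass;  // only the MOV line survives
    }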
@@ -1,56 +0,0 @@
-import triton
-import numpy as np
-from enum import Enum
-
-class MODE(Enum):
-  TF = 1
-  TORCH = 2
-
-try:
-  import tensorflow as tf
-  mode = MODE.TF
-except ModuleNotFoundError:
-  pass
-
-try:
-  import torch
-  mode = MODE.TORCH
-except ModuleNotFoundError:
-  pass
-
-C, H, W, B = 32, 1, 1, 128
-
-x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
-gamma = np.random.uniform(-1, 1, C).astype(np.float32)
-beta = np.random.uniform(-1, 1, C).astype(np.float32)
-dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32)
-
-if mode == MODE.TORCH:
-  fw_x = torch.from_numpy(x).cuda()
-  fw_gamma = torch.from_numpy(gamma).cuda()
-  fw_beta = torch.from_numpy(beta).cuda()
-  fw_dy = torch.from_numpy(dy).cuda()
-  # register gradients
-  fw_x.requires_grad_(True)
-  fw_gamma.requires_grad_(True)
-  fw_beta.requires_grad_(True)
-  # execute
-  fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
-  fw_y.backward(fw_dy)
-
-if mode == MODE.TF:
-  fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
-  fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
-  fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
-  fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
-  # execute
-  fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
-  fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
-  fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
-  # run
-  sess = tf.InteractiveSession()
-  feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
-  sess.run(tf.global_variables_initializer())
-  result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
-  print(result)
@@ -1,213 +0,0 @@
-import triton
-import torch
-from torch.utils.cpp_extension import load
-import numpy as np
-#import utils
-from time import time
-
-torch.manual_seed(0)
-
-#torch.backends.cudnn.benchmark = True
-
-configs = []
-
-# Matrix multiplication
-MNK = [
-       (512, 512 ,512),
-       (2048, 2048, 2048),
-       #(8192, 8192, 8192),
-
-       (64, 64, 64000),
-       (64, 64, 128000),
-       (256, 256, 64000),
-       (256, 256, 128000),
-
-       (1536, 16, 1536),
-       (1536, 32, 1536),
-       (1536, 64, 1536),
-       # (1536, 128, 1536),
-       # (4096, 16, 4096),
-       # (4096, 32, 4096),
-       # (4096, 64, 4096),
-       # (4096, 128, 4096),
-
-       # (127008, 768, 576)
-      ]
-for M, N, K in MNK:
-    matmul = lambda a, b: torch.matmul(a, b)
-    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)]
-#for M, N, K in MNK:
-#    matmul = lambda a, b: torch.matmul(a.t(), b)
-#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)]
-#for M, N, K in MNK:
-#    matmul = lambda a, b: torch.matmul(a, b.t())
-#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)]
-
-# Relative attention
-NTHSE = [
-         (16, 512, 1, 64, 64),
-         # (16, 512, 1, 128, 128),
-         # (16, 512, 1, 256, 256),
-         # (16, 512, 1, 256, 512),
-         (16, 512, 8, 64, 64),
-         # (16, 512, 8, 128, 128),
-         # (16, 512, 8, 256, 256),
-         # (16, 512, 8, 256, 512),
-
-         # (64, 1024, 1, 64, 64),
-         (64, 1024, 1, 128, 128),
-         # (64, 1024, 1, 256, 256),
-         # (64, 1024, 1, 256, 512),
-         # (64, 1024, 8, 64, 64),
-         (64, 1024, 8, 128, 128),
-         # (64, 1024, 8, 256, 256),
-         # (64, 1024, 8, 256, 512),
-
-         # (128, 1024, 1, 64, 64),
-         # (128, 1024, 1, 128, 128),
-         # (128, 1024, 1, 256, 256),
-         (128, 1024, 1, 256, 512),
-         # (128, 1024, 8, 64, 64),
-         # (128, 1024, 8, 128, 128),
-         # (128, 1024, 8, 256, 256),
-         #(128, 1024, 8, 256, 512)
-        ]
-#for N, T, H, S, E in NTHSE:
-#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)]
-#for N, T, H, S, E in NTHSE:
-#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)]
-#for N, T, H, S, E in NTHSE:
-#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)]
-
-# 1D Dense convolution
-NCHKR = [
-         #(1, 1152, 12602, 512, 3)
-        ]
-for N, C, H, K, R in NCHKR:
-    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
-    configs += [([N, C, H],
-                 [C, R, K],
-                 [N, K, H - R + 1],
-                 torch_fn,
-                 'nc(h+r),crk->nkh',
-                 dict(), None, None, None)]
-
-# 2D Dense convolution
-NCHWKRS = [
-           #(8, 64, 128, 128, 768, 3, 3),
-           #(128, 3, 32, 32, 64, 3, 3),
-           #(1, 1024, 32, 112, 112, 1024, 3, 3),
-           #(8, 512, 32, 32, 1024, 3, 3)
-          ]
-for N, C, G, H, W, K, R, S in NCHWKRS:
-    stride = 2
-    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G)
-    P = (H - R + 1) // stride
-    Q = (W - S + 1) // stride
-    transform_a = lambda a: a.view(N, G, C // G, H, W)
-    transform_b = lambda b: b.view(C // G, R, S, G, K // G)
-    transform_c = lambda c: c.view(N, K, P, Q)
-    configs += [([N, C, H, W],
-                 [C // G, R, S, K],
-                 [N, G, K // G, P, Q],
-                 torch_fn,
-                 'ngc(h*2+r)(w*2+s),crsgk->ngkhw',
-                 dict(), transform_a, transform_b, transform_c)]
-
-# 3D Dense Convolution
-NCDHWKTRS = [
-             #(8, 32, 27, 100, 100, 64, 3, 3, 3),
-             #(8, 64, 23, 48, 48, 256, 3, 3, 3),
-             #(8, 256, 19, 22, 22, 640, 3, 3, 3),
-             #(8, 640, 15, 36, 36, 384, 3, 3, 3)
-            ]
-for N, C, D, H, W, K, T, R, S in NCDHWKTRS:
-    torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3))
-    configs += [([N, C, D, H, W],
-                 [C, T, R, S, K],
-                 [N, K, D - T + 1, H - R + 1, W - R + 1],
-                 torch_fn,
-                 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw',
-                 dict(), None, None, None)]
-
-# Shift convolution
-shift_cuda = torch.utils.cpp_extension.load(
-    'shift_cuda', ['kernels/shift_cuda.cpp',
-                   'kernels/shift_cuda_kernel.cu'],
-    extra_cflags=['-O3'])
-class shift(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, shift):
-        ctx.save_for_backward(shift)
-        return shift_cuda.forward(x, shift)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        shift, = ctx.saved_tensors
-        grad_output = shift_cuda.backward(grad_output, shift)
-        return grad_output, None
-
-NCHWKRS = [
-           #(8, 64, 128, 128, 128, 3, 3),
-           #(8, 128, 64, 64, 256, 3, 3),
-           #(8, 256, 32, 32, 512, 3, 3),
-           #(8, 512, 32, 32, 1024, 3, 3)
-          ]
-for N, C, H, W, K, R, S in NCHWKRS:
-    shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2
-    shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2
-    def shift_conv(a, b, **kwargs):
-        shift_h, shift_w = kwargs['sh'], kwargs['sw']
-        shift_torch = np.column_stack((shift_w*-1, shift_h*-1))
-        shift_torch = torch.from_numpy(shift_torch).cuda()
-        a = shift.apply(a, shift_torch)
-        b = b.permute(1, 0)
-        b = b.reshape(b.shape[0], b.shape[1], 1, 1)
-        return torch.nn.functional.conv2d(a, b)
-    configs += [([N, C, H, W],
-                 [C, K],
-                 [N, K, H, W],
-                 shift_conv,
-                 'nc(h + sh[c])(w + sw[c]),ck->nkhw',
-                 {'sh': shift_h, 'sw': shift_w},
-                 None, None, None)]
-
-# Benchmark
-torch.set_num_threads(1)
-for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \
-    transform_a, transform_b, transform_c in configs:
-    dtype = torch.cuda.FloatTensor
-    # initialize input tensors
-    a = torch.rand(*a_shape).type(dtype).cuda()
-    b = torch.rand(*b_shape).type(dtype).cuda()
-    # reference output
-    if torch_fn:
-        rc = torch_fn(a, b, **arrays)
-    else:
-        rc = torch.einsum(expr, a, b)
-    # triton output
-    ta = a if transform_a is None else transform_a(a)
-    tb = b if transform_b is None else transform_b(b)
-    tc = torch.empty(c_shape, device=a.device)
-    triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True)
-    ctx = triton.ops._einsum.registry[tc]
-    tc = tc if transform_c is None else transform_c(tc)
-    # performance relative to equivalent matrix multiplication
-    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
-    cmp_eqbmm = True
-    if cmp_eqbmm:
-        a = torch.rand(B, M, K).type(dtype).cuda()
-        b = torch.rand(B, K, N).type(dtype).cuda()
-        c = torch.empty((B, M, N), device=a.device).cuda()
-        tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True)
-        ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms
-        cmp_str = f'({ratio:4.2f})'
-    else:
-        cmp_str = ''
-    # test and benchmark
-    bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
-    diff = (tc - rc).abs().max() / rc.abs().max()
-    print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
@@ -1,42 +0,0 @@
-#include <torch/torch.h>
-
-#include <vector>
-
-// CUDA forward declarations
-
-at::Tensor shift_cuda_forward(
-    const at::Tensor input,
-    const at::Tensor shift);
-
-at::Tensor shift_cuda_backward(
-    const at::Tensor grad_input,
-    const at::Tensor shift);
-
-// C++ interface
-
-// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-at::Tensor shift_forward(
-    const at::Tensor input,
-    const at::Tensor shift) {
-  CHECK_INPUT(input);
-  CHECK_INPUT(shift);
-
-  return shift_cuda_forward(input, shift);
-}
-
-at::Tensor shift_backward(
-    const at::Tensor grad_input,
-    const at::Tensor shift) {
-  CHECK_INPUT(grad_input);
-  CHECK_INPUT(shift);
-  return shift_cuda_backward(grad_input, shift);
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &shift_forward, "Shift forward (CUDA)");
-  m.def("backward", &shift_backward, "Shift backward (CUDA)");
-}
@@ -1,111 +0,0 @@
-#include <ATen/ATen.h>
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#include <vector>
-
-namespace {
-template <typename scalar_t>
-__global__ void shift_cuda_forward_kernel(
-    const scalar_t* __restrict__ input,
-    const int32_t* __restrict__ shift,
-    scalar_t* __restrict__ output,
-    const int32_t B,
-    const int32_t C,
-    const int32_t H,
-    const int32_t W) {
-  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  const int32_t size = B*C*H*W;
-
-  const int32_t CHW = C*H*W;
-  const int32_t HW = H*W;
-  const int32_t b = idx / CHW;
-  const int32_t c = (idx - b*CHW) / HW;
-  const int32_t h = (idx - b*CHW - c*HW) / W;
-  const int32_t w = idx - b*CHW - c*HW - h*W;
-  const int32_t target_w = w + shift[2*c];
-  const int32_t target_h = h + shift[2*c + 1];
-  const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w;
-  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
-    output[target_idx] = input[idx];
-  }
-}
-
-template <typename scalar_t>
-__global__ void shift_cuda_backward_kernel(
-    const scalar_t* __restrict__ grad_input,
-    scalar_t* __restrict__ grad_output,
-    const int32_t* __restrict__ shift,
-    const int32_t B,
-    const int32_t C,
-    const int32_t W,
-    const int32_t H) {
-  const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  const int32_t size = B*C*W*H;
-  const int32_t CWH = C*W*H;
-  const int32_t WH = W*H;
-  const int32_t b = idx / CWH;
-  const int32_t c = (idx - b*CWH) / WH;
-  const int32_t w = (idx - b*CWH - c*WH) / W;
-  const int32_t h = idx - b*CWH - c*WH - w*H;
-  const int32_t target_w = w - shift[2*c];
-  const int32_t target_h = h - shift[2*c + 1];
-  const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h;
-  if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) {
-    grad_output[target_idx] = grad_input[idx];
-  }
-}
-} // namespace
-
-at::Tensor shift_cuda_forward(
-    const at::Tensor input,
-    const at::Tensor shift) {
-  const auto B = input.size(0);
-  const auto C = input.size(1);
-  const auto H = input.size(2);
-  const auto W = input.size(3);
-  const auto size = B*C*W*H;
-  const int threads = 1024;
-  const int blocks = (size + threads - 1) / threads;
-  auto output = at::zeros_like(input);
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] {
-    shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        input.data<scalar_t>(),
-        shift.data<int32_t>(),
-        output.data<scalar_t>(),
-        B,
-        C,
-        H,
-        W);
-  }));
-
-  return output;
-}
-
-at::Tensor shift_cuda_backward(
-    const at::Tensor grad_input,
-    const at::Tensor shift) {
-  const auto B = grad_input.size(0);
-  const auto C = grad_input.size(1);
-  const auto H = grad_input.size(2);
-  const auto W = grad_input.size(3);
-  const auto size = B*C*W*H;
-  const int threads = 1024;
-  const int blocks = (size + threads - 1) / threads;
-  auto grad_output = at::zeros_like(grad_input);
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] {
-    shift_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        grad_input.data<scalar_t>(),
-        grad_output.data<scalar_t>(),
-        shift.data<int32_t>(),
-        B,
-        C,
-        H,
-        W);
-  }));
-
-  return grad_output;
-}
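The removed kernel implements a per-channel spatial shift: element (b, c, h, w) is scattered to (b, c, h + shift[c][1], w + shift[c][0]) whenever the target stays in bounds. A hedged NumPy reference for the forward pass, useful for checking the indexing against the kernel:

    import numpy as np

    def shift_forward_ref(x, shift):
        """x: (B, C, H, W); shift: (C, 2) int array of (dw, dh) per channel."""
        B, C, H, W = x.shape
        y = np.zeros_like(x)
        for c in range(C):
            dw, dh = int(shift[c, 0]), int(shift[c, 1])
            for h in range(H):
                for w in range(W):
                    th, tw = h + dh, w + dw
                    if 0 <= th < H and 0 <= tw < W:
                        y[:, c, th, tw] = x[:, c, h, w]   # scatter, like the CUDA kernel
        return y
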
@@ -1,109 +0,0 @@
-import triton
-import numpy
-import torch
-import itertools
-
-torch.manual_seed(0)
-numpy.random.seed(0)
-
-def to_sparse(expr, data, layout, shape, block):
-    # shape of result
-    sparse = None
-    shape_ret = []
-    for i, d in enumerate(expr):
-        if d.isupper() and sparse is None:
-            sparse = i
-            shape_ret.append(int(layout.sum()))
-        if d.isupper():
-            shape_ret.append(block[d])
-        else:
-            shape_ret.append(shape[i])
-    # iterator
-    steps = [block[d] if d.isupper() else 1 for d in expr]
-    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
-    # create result
-    ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
-    blockid = 0
-    nzblockid = 0
-    for curr in itertools.product(*it):
-        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
-            blockid = 0
-            nzblockid = 0
-        data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
-        ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
-        ret_slice.insert(sparse, nzblockid)
-        if int(layout.view(-1)[blockid]):
-            ret[ret_slice] = data[data_slice]
-            nzblockid += 1
-        blockid += 1
-    return ret
-
-def to_dense(expr, data, layout, shape, block):
-    sparse = None
-    for i, d in enumerate(expr):
-        if d.isupper() and sparse is None:
-            sparse = i
-
-    ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
-    steps = [block[d] if d.isupper() else 1 for d in expr]
-    it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
-    blockid = 0
-    nzblockid = 0
-    for curr in itertools.product(*it):
-        if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
-            blockid = 0
-            nzblockid = 0
-        ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
-        data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
-        data_slice.insert(sparse, nzblockid)
-        if int(layout.view(-1)[blockid]):
-            ret[ret_slice] = data[data_slice]
-            nzblockid += 1
-        blockid += 1
-    return ret
-
-def test_expr(expr, shape, blocks):
-    # decompose expr
-    expr_a, expr_bc = expr.split(",")
-    expr_b, expr_c = expr_bc.split("->")
-    # check which argument is sparse
-    sparse_a = any(x.isupper() for x in expr_a)
-    sparse_b = any(x.isupper() for x in expr_b)
-    sparse_c = any(x.isupper() for x in expr_c)
-    # allocate data
-    shape_a = [shape[d.lower()] for d in expr_a]
-    shape_b = [shape[d.lower()] for d in expr_b]
-    shape_c = [shape[d.lower()] for d in expr_c]
-    ref_a = torch.rand(*shape_a, device='cuda')
-    ref_b = torch.rand(*shape_b, device='cuda')
-    ref_c = torch.zeros(*shape_c, device='cuda')
-    # layouts
-    layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
-    layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
-    layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
-    layout_a = torch.randint(0, 2, layout_a, device='cuda')
-    layout_b = torch.randint(0, 2, layout_b, device='cuda')
-    layout_c = torch.randint(0, 2, layout_c, device='cuda')
-    # triton computation
-    triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
-    triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
-    layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
-    triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
-    torch.cuda.synchronize()
-    # reference computation
-    ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
-    ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
-    ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
-    if sparse_c:
-        ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
-    torch.cuda.synchronize()
-    print((ref_c - triton_c).abs().max())
-
-# shape characteristics
-test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
-test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
-test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})
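In these helpers, uppercase letters in the einsum expression mark blocked (sparse) dimensions: the first uppercase position gains a leading "number of non-zero blocks" axis, and each uppercase dimension contributes a block-size axis. A hedged worked example of the resulting compressed shape, following the to_sparse shape logic above:

    # expr = 'bHMK' with shape (2, 2, 256, 256) and blocks {'H': 1, 'M': 32, 'K': 32}:
    # the block layout has shape (2//1, 256//32, 256//32) = (2, 8, 8), i.e. 128 blocks.
    # If, say, 70 of them are non-zero, to_sparse returns a tensor of shape
    #     (2, 70, 1, 32, 32)
    #      b  nnz H   M   K
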
@@ -171,7 +171,7 @@ class _conv(torch.autograd.Function):
             _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
         delta, kernel = _conv.kernel[dtype]
         # allocate output
-        c = triton.empty([Z, CO, P, Q], dtype=dtype)
+        c = torch.empty([Z, CO, P, Q], dtype=dtype)
         # enqueue
         grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
                             triton.cdiv(CO, opt.d('TN'))]
@@ -3,6 +3,9 @@ import triton

 class _dot(torch.autograd.Function):
     src = """
+#define STM 4
+#define STN 4
+
 __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     TYPE * B __noalias __readonly __aligned(16),
                     TYPE * C __noalias __aligned(16),
@@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     int ldb __multipleof(8),
                     int ldc __multipleof(8)) {
       // prologue
-      int ridx = get_program_id(0);
-      int ridy = get_program_id(1);
-      int ridz = get_program_id(2);
-      int gridx = M / TM;
-      int gridy = N / TN;
-      int rid = ridx + ridy * gridx;
-      ridx = rid / gridy;
-      ridy = rid % gridy;
-      int rm[TM] = ridx * TM + 0 ... TM;
-      int rn[TN] = ridy * TN + 0 ... TN;
+      int pid = get_program_id(0);
+      int pidz = get_program_id(2);
+      int gridm = M / TM;
+      int gridn = N / TN;
+      int stgridm = (gridm + STM - 1) / STM;
+      int stgridn = (gridn + STN - 1) / STN;
+      int stid = pid / (STM * STN);
+      int laneid = pid % (STM * STN);
+      int stm = stid / stgridn;
+      int stn = stid % stgridn;
+      int lanem = laneid / STN;
+      int lanen = laneid % STN;
+      int pidm = stm*STM + lanem;
+      int pidn = stn*STN + lanen;
+      int rm[TM] = pidm * TM + 0 ... TM;
+      int rn[TN] = pidn * TN + 0 ... TN;

       // reduction splitting
       K           = K / TZ;
-      int rk[TK]  = ridz * K + 0 ... TK;
+      int rk[TK]  = pidz * K + 0 ... TK;

       // pointers to operands
       int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
@@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
       // reduction loop
       float acc[TM, TN] = 0;
       for(int k = K; k > 0; k -= TK){
-        acc += a @ b;
        bool checka[TM, TK] = k > TK;
        bool checkb[TK, TN] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
+        acc += a @ b;
        a = *?(checka)pa;
        b = *?(checkb)pb;
      }
@@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
       TYPE c[TM, TN] = acc;

       // epilogue
-      int rxm[TM] = ridx * TM + 0 ... TM;
-      int rxn[TN] = ridy * TN + 0 ... TN;
+      int rxm[TM] = pidm * TM + 0 ... TM;
+      int rxn[TN] = pidn * TN + 0 ... TN;
       int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
       TYPE* pc[TM, TN] = C + offc;
       bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
       *?(checkc) pc = c;
 #else
       // accumulate partial result using spin-locks
-      int *plock  = locks + rid;
+      int *plock  = locks + pid;
       int *pcount = plock + get_num_programs(0) * get_num_programs(1);
       for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
       int count = *pcount;
@@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                        'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
                        'TM'   : [128],
                        'TN'   : [128],
-                       'TK'   : [16],
+                       'TK'   : [32],
                        'TZ'   : [1]
                    }
         _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
         M, K = a.shape
         K, N = b.shape
         c = torch.empty([M,N], dtype=dtype, device=a.device)
+        print(kernel.asm('sass', c.device))
+        print(kernel.asm('ptx', c.device))
         # enqueue
-        grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
-                            triton.cdiv(N, opt.d('TN'))]
+        grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
         time = kernel(a, b, c, 1., M, N, K,
                       a.stride(0), b.stride(0), c.stride(0), grid=grid)
         return c
@@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half()

 zc  = torch.matmul(a,b)
 zc_ = dot(a,b)
-
-
 print(torch.allclose(zc, zc_))
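The new prologue flattens the launch grid to one dimension and re-maps consecutive program ids into STM x STN "super-blocks" of output tiles, so that programs scheduled close in time touch neighboring C tiles and reuse A/B rows and columns from L2. A hedged Python sketch mirroring the kernel's integer arithmetic:

    def remap(pid, gridm, gridn, STM=4, STN=4):
        """Mirror of the kernel's pid -> (pidm, pidn) swizzle."""
        stgridn = (gridn + STN - 1) // STN    # super-blocks along n
        stid = pid // (STM * STN)             # which super-block this program is in
        laneid = pid % (STM * STN)            # position inside the super-block
        stm, stn = stid // stgridn, stid % stgridn
        lanem, lanen = laneid // STN, laneid % STN
        return stm * STM + lanem, stn * STN + lanen

    # e.g. with a 32x32 tile grid, the first 16 program ids all land in the
    # same 4x4 neighborhood of C tiles:
    #   [remap(p, 32, 32) for p in range(16)]
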
@@ -111,7 +111,7 @@ setup(
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
-    packages=['triton', 'triton/_C', 'triton/ops'],
+    packages=['triton', 'triton/_C'],
     install_requires=['numpy', 'torch', 'sympy'],
     package_data={'': data},
     ext_modules=[CMakeExtension('triton', 'triton/_C/')],
@@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) {

 void register_fn(const map_key_t& key,
                  const std::string& src,
-                 const rt::function::options_space_t& opt) {
+                 const rt::options_space_t& opt) {
   if(id_fn_map.find(key) == id_fn_map.end())
     id_fn_map[key].reset(new rt::function(src, opt, ""));
 }
@@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) {
   id_fn_map.erase(key);
 }

-std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
-  triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
-  return id_fn_map[key]->ptx(&device, opt);
+std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) {
+  triton::driver::cu_device device(key.second, false);
+  return id_fn_map[key]->get_asm(mode, &device, opt);
 }

 void cleanup() {
@@ -63,7 +63,7 @@ size_t make_op_id() {

 /* Function signature */
 void make_module(const std::string& src, ir::module* ir,
-                 const runtime::function::options_space_t& opt) {
+                 const runtime::options_space_t& opt) {
   std::string copy = triton::runtime::function::preheader() + src;
   // pre-process
   TokenSequence tokens;
@@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir,
 }

 std::vector<rt::arg_type> get_fn_signature(const std::string& src,
-                                           const runtime::function::options_space_t& opt) {
+                                           const runtime::options_space_t& opt) {
   // triton-ir code-gen
   ir::context ctx;
   auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -95,8 +95,8 @@ std::vector<rt::arg_type> get_fn_signature(const std::string& src,
   return ret;
 }

-typedef triton::runtime::function::options_t options_t;
-typedef triton::runtime::function::options_space_t options_space_t;
+typedef triton::runtime::options_t options_t;
+typedef triton::runtime::options_space_t options_space_t;

 PYBIND11_MODULE(libtriton, m) {
     m.doc() = "Python bindings to the C++ Triton API";
@@ -113,6 +113,10 @@ PYBIND11_MODULE(libtriton, m) {
         .value("double", rt::DOUBLE_T)
         .value("buffer", rt::BUFFER_T);

+    pybind11::enum_<rt::asm_mode_t>(m, "asm_mode")
+        .value("ptx", rt::ASM_NV_PTX)
+        .value("sass", rt::ASM_NV_SASS);
+
     pybind11::class_<options_t>(m, "options")
         .def(pybind11::init<>())
         .def("d", &options_t::D<int>)
@@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) {

     // hooks into triton constructs since frameworks may not use pybind11
     m.def("get_fn_signature", &get_fn_signature);
-    m.def("get_fn_ptx", &get_fn_ptx);
+    m.def("get_fn_asm", &get_fn_asm);
     m.def("register_grid", &register_grid);
     m.def("delete_grid", &delete_grid);
     m.def("register_fn", &register_fn);
@@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) {
   }
 }

-torch::Tensor raw_like(torch::Tensor x){
+torch::Tensor cuda_empty_like(torch::Tensor x){
   if(x.nbytes() == 0)
     return torch::empty_like(x);
-  C10_CUDA_CHECK(cudaSetDevice(x.device().index()));
-  auto shape = x.sizes();
-  CUdeviceptr data;
-  triton::driver::dispatch::cuMemAlloc(&data, x.nbytes());
-  auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); };
-  auto ret = torch::from_blob((void*)data, shape, deleter, x.options());
-  ret.copy_(x);
+  void* data;
+  cudaMalloc(&data, x.nbytes());
+  auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options());
   return ret;
 }

@@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,

 static auto registry = torch::RegisterOperators()
                        .op("triton::launch_kernel", &launch_kernel)
-                       .op("triton::raw_like", &raw_like)
+                       .op("triton::cuda_empty_like", &cuda_empty_like)
                        .op("triton::cdiv_sum", &cdiv_sum)
                        .op("triton::synchronize", &synchronize);
@@ -1,7 +1,4 @@
 from .kernel import *
-import triton.ops
-#import triton.nn
-

 # clean-up libtriton resources
 import atexit
@@ -68,8 +68,17 @@ class kernel:
         size = sum([sizes[x] for x in arg_types])
         self.tys = ''.join([codes[x] for x in arg_types])

-    def ptx(self, device, **kwargs):
+    def asm(self, mode, device, **kwargs):
         dev_id = device.index
+        # assembly mode
+        supported = {
+            'ptx': libtriton.asm_mode.ptx,
+            'sass': libtriton.asm_mode.sass,
+        }
+        if mode not in supported:
+            raise('ASM mode must be in ', supported.keys())
+        mode = supported[mode]
+        # disambiguates #defines
         libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
         def _single_value_or_err(x, key):
             if isinstance(x, list) and len(x) == 1:
@@ -86,15 +95,18 @@ class kernel:
         opt = libtriton.options()
         opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps')
         opt.defines = defines
-        return libtriton.get_fn_ptx((self.op_id, dev_id), opt)
+        # run
+        return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt)

     def __call__(self, *args, **kwargs):
         if 'TRITON_DEBUG_MODE' in os.environ:
             _args = args
-            args = [x for x in args]
+            args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args]
             for i in range(len(args)):
                 if isinstance(args[i], torch.Tensor):
-                    args[i] = torch.ops.triton.raw_like(args[i])
+                    args[i] = torch.ops.triton.cuda_empty_like(args[i])
+                    args[i].copy_(_args[i])
+            torch.cuda.synchronize()
         for x in args:
             if isinstance(x, torch.Tensor):
                 device = x.device.index
@@ -116,6 +128,8 @@ class kernel:
         constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
         torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
         if 'TRITON_DEBUG_MODE' in os.environ:
+            torch.cuda.synchronize()
             for i in range(len(args)):
                 if isinstance(args[i], torch.Tensor):
-                    _args[i].copy_(args[i])
+                    _args[i].copy_(args[i].clone())
+            args = _args
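Taken together, these binding changes expose both assembly dumps and a shadow-copy debug mode from Python. A hedged usage sketch (kernel construction elided; `src` and `defines` stand in for a real Triton-C source string and option space):

    import os, torch, triton

    os.environ['TRITON_DEBUG_MODE'] = '1'     # run kernels on fresh CUDA copies of the args
    kernel = triton.kernel(src, num_warps=[4], defines=defines)
    # inspect the generated code for a given device
    print(kernel.asm('ptx',  torch.device('cuda:0')))
    print(kernel.asm('sass', torch.device('cuda:0')))
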
@@ -1,2 +0,0 @@
-from .conv import replace_conv2d
-from .attention import replace_mah
@@ -1,312 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import triton
-
-def bmm(x, w, mask = None):
-    b, m, k = x.size()
-    b, k, n = w.size()
-    out = torch.empty([b, m, n], device=x.device)
-    triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False)
-    return out
-
-def multi_head_attention_forward(query,                           # type: Tensor
-                                 key,                             # type: Tensor
-                                 value,                           # type: Tensor
-                                 embed_dim_to_check,              # type: int
-                                 num_heads,                       # type: int
-                                 in_proj_weight,                  # type: Tensor
-                                 in_proj_bias,                    # type: Tensor
-                                 bias_k,                          # type: Optional[Tensor]
-                                 bias_v,                          # type: Optional[Tensor]
-                                 add_zero_attn,                   # type: bool
-                                 dropout_p,                       # type: float
-                                 out_proj_weight,                 # type: Tensor
-                                 out_proj_bias,                   # type: Tensor
-                                 training=True,                   # type: bool
-                                 key_padding_mask=None,           # type: Optional[Tensor]
-                                 need_weights=True,               # type: bool
-                                 attn_mask=None,                  # type: Optional[Tensor]
-                                 use_separate_proj_weight=False,  # type: bool
-                                 q_proj_weight=None,              # type: Optional[Tensor]
-                                 k_proj_weight=None,              # type: Optional[Tensor]
-                                 v_proj_weight=None,              # type: Optional[Tensor]
-                                 static_k=None,                   # type: Optional[Tensor]
-                                 static_v=None,                   # type: Optional[Tensor]
-                                 acc_bitmask=None
-                                 ):
-    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
-    r"""
-    Args:
-        query, key, value: map a query and a set of key-value pairs to an output.
-            See "Attention Is All You Need" for more details.
-        embed_dim_to_check: total dimension of the model.
-        num_heads: parallel attention heads.
-        in_proj_weight, in_proj_bias: input projection weight and bias.
-        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
-        add_zero_attn: add a new batch of zeros to the key and
-                       value sequences at dim=1.
-        dropout_p: probability of an element to be zeroed.
-        out_proj_weight, out_proj_bias: the output projection weight and bias.
-        training: apply dropout if is ``True``.
-        key_padding_mask: if provided, specified padding elements in the key will
-            be ignored by the attention. This is a binary mask. When the value is True,
-            the corresponding value on the attention layer will be filled with -inf.
-        need_weights: output attn_output_weights.
-        attn_mask: mask that prevents attention to certain positions. This is an additive mask
-            (i.e. the values will be added to the attention layer).
-        use_separate_proj_weight: the function accepts the proj. weights for query, key,
-            and value in different forms. If false, in_proj_weight will be used, which is
-            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
-        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
-        static_k, static_v: static key and value used for attention operators.
-
-    Shape:
-        Inputs:
-        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-          the embedding dimension.
-        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-          the embedding dimension.
-        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-          the embedding dimension.
-        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
-        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
-        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
-          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
-        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
-          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
-
-        Outputs:
-        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
-          E is the embedding dimension.
-        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
-          L is the target sequence length, S is the source sequence length.
-    """
-
-    tgt_len, bsz, embed_dim = query.size()
-    assert embed_dim == embed_dim_to_check
-    assert key.size() == value.size()
-
-    head_dim = embed_dim // num_heads
-    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
-    scaling = float(head_dim) ** -0.5
-
-    if not use_separate_proj_weight:
-        if torch.equal(query, key) and torch.equal(key, value):
-            # self-attention
-            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
-
-        elif torch.equal(key, value):
-            # encoder-decoder attention
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = F.linear(query, _w, _b)
-
-            if key is None:
-                assert value is None
-                k = None
-                v = None
-            else:
-                # This is inline in_proj function with in_proj_weight and in_proj_bias
-                _b = in_proj_bias
-                _start = embed_dim
-                _end = None
-                _w = in_proj_weight[_start:, :]
-                if _b is not None:
-                    _b = _b[_start:]
-                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
-
-        else:
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = F.linear(query, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = embed_dim * 2
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            k = F.linear(key, _w, _b)
-
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim * 2
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            v = F.linear(value, _w, _b)
-    else:
-        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
-        len1, len2 = q_proj_weight_non_opt.size()
-        assert len1 == embed_dim and len2 == query.size(-1)
-
-        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
-        len1, len2 = k_proj_weight_non_opt.size()
-        assert len1 == embed_dim and len2 == key.size(-1)
-
-        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
-        len1, len2 = v_proj_weight_non_opt.size()
-        assert len1 == embed_dim and len2 == value.size(-1)
-
-        if in_proj_bias is not None:
-            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
-            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
-            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
-        else:
-            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
-            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
-            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
-    q = q * scaling
-
-    if bias_k is not None and bias_v is not None:
-        if static_k is None and static_v is None:
-            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
-            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
-            if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask,
-                                       torch.zeros((attn_mask.size(0), 1),
-                                                   dtype=attn_mask.dtype,
-                                                   device=attn_mask.device)], dim=1)
-            if key_padding_mask is not None:
-                key_padding_mask = torch.cat(
-                    [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
-                                                   dtype=key_padding_mask.dtype,
-                                                   device=key_padding_mask.device)], dim=1)
-        else:
-            assert static_k is None, "bias cannot be added to static key."
-            assert static_v is None, "bias cannot be added to static value."
-    else:
-        assert bias_k is None
-        assert bias_v is None
-
-    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
-    if k is not None:
-        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
-    if v is not None:
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
-
-    if static_k is not None:
-        assert static_k.size(0) == bsz * num_heads
-        assert static_k.size(2) == head_dim
-        k = static_k
-
-    if static_v is not None:
-        assert static_v.size(0) == bsz * num_heads
-        assert static_v.size(2) == head_dim
-        v = static_v
-
-    src_len = k.size(1)
-
-    if key_padding_mask is not None:
-        assert key_padding_mask.size(0) == bsz
-        assert key_padding_mask.size(1) == src_len
-
-    if add_zero_attn:
-        src_len += 1
-        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
-        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
-        if attn_mask is not None:
-            attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
-                                                          dtype=attn_mask.dtype,
-                                                          device=attn_mask.device)], dim=1)
-        if key_padding_mask is not None:
-            key_padding_mask = torch.cat(
-                [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
-                                               dtype=key_padding_mask.dtype,
-                                               device=key_padding_mask.device)], dim=1)
-
-    attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask)
-    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
-
-    if attn_mask is not None:
-        attn_mask = attn_mask.unsqueeze(0)
-        attn_output_weights += attn_mask
-
-    if key_padding_mask is not None:
-        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
-        attn_output_weights = attn_output_weights.masked_fill(
-            key_padding_mask.unsqueeze(1).unsqueeze(2),
-            float('-inf'),
-        )
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
-
-    attn_output_weights = F.softmax(
-        attn_output_weights, dim=-1)
-    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
-
-    attn_output = bmm(attn_output_weights, v, mask=acc_bitmask)
-    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
-    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
-
-    if need_weights:
-        # average attention weights over heads
-        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
-        return attn_output, attn_output_weights.sum(dim=1) / num_heads
-    else:
-        return attn_output, None
-
-class MultiheadAttention(nn.modules.activation.MultiheadAttention):
-
-    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None):
-        super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim)
-        self.acc_bitmask = acc_bitmask
-
-    def forward(self, query, key, value, key_padding_mask=None,
-                need_weights=True, attn_mask=None):
-        if not self._qkv_same_embed_dim:
-            return multi_head_attention_forward(
-                query, key, value, self.embed_dim, self.num_heads,
-                self.in_proj_weight, self.in_proj_bias,
-                self.bias_k, self.bias_v, self.add_zero_attn,
-                self.dropout, self.out_proj.weight, self.out_proj.bias,
-                training=self.training,
-                key_padding_mask=key_padding_mask, need_weights=need_weights,
-                attn_mask=attn_mask, use_separate_proj_weight=True,
-                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
-                v_proj_weight=self.v_proj_weight,
-                acc_bitmask=self.acc_bitmask)
-        else:
-            return multi_head_attention_forward(
-                query, key, value, self.embed_dim, self.num_heads,
-                self.in_proj_weight, self.in_proj_bias,
-                self.bias_k, self.bias_v, self.add_zero_attn,
-                self.dropout, self.out_proj.weight, self.out_proj.bias,
-                training=self.training,
-                key_padding_mask=key_padding_mask, need_weights=need_weights,
-                attn_mask=attn_mask,
-                acc_bitmask=self.acc_bitmask)
-
-def replace_mah(model, mask = None):
-    for child_name, child in model.named_children():
-        if isinstance(child, nn.modules.activation.MultiheadAttention):
-            add_bias_kv = child.bias_k is not None
-            device = child.in_proj_weight.device
-            mah = MultiheadAttention(child.embed_dim, child.num_heads,
-                                     dropout=child.dropout, add_bias_kv=add_bias_kv,
-                                     add_zero_attn=child.add_zero_attn, kdim=child.kdim,
-                                     vdim=child.vdim, acc_bitmask=mask).to(device)
-            for yparam, xparam in zip(mah.parameters(), child.parameters()):
-                yparam.data.copy_(xparam.data)
-            setattr(model, child_name, mah)
-        else:
-            replace_mah(child, mask)
@@ -1,166 +0,0 @@
|
|||||||
import triton
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class _conv2d(torch.autograd.Function):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, x, weight, bias,
|
|
||||||
stride, padding, dilation, groups,
|
|
||||||
acc_bitmask):
|
|
||||||
assert dilation == (1, 1)
|
|
||||||
assert groups == 1
|
|
||||||
assert bias == None
|
|
||||||
pad_h, pad_w = padding
|
|
||||||
stride_h, stride_w = stride
|
|
||||||
n, c, h, w = x.size()
|
|
||||||
k, c, r, s = weight.size()
|
|
||||||
# allocate output
|
|
||||||
        p = (h + 2*padding[0] - r)//stride[0] + 1
        q = (w + 2*padding[1] - s)//stride[1] + 1
        output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device)
        # padding
        if pad_h or pad_w:
            x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h])
        # convolution
        triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
                          x, weight, mask=acc_bitmask,
                          output=output,
                          values={'pad_h': pad_h,
                                  'stride_h': stride_h,
                                  'pad_w': pad_w,
                                  'stride_w': stride_w})
        # prepare backprop
        ctx.save_for_backward(x, weight)
        ctx.stride = stride
        ctx.padding = padding
        ctx.acc_bitmask = acc_bitmask
        # return
        return output

    @staticmethod
    def backward(ctx, dy):
        # retrieve contextual information
        x, weight = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        acc_bitmask = ctx.acc_bitmask
        # gradient of the input
        dx = None
        if ctx.needs_input_grad[0]:
            # dy must be padded
            n, k, p, q = dy.size()
            n, c, h, w = x.size()
            k, c, r, s = weight.size()
            dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
            # have to be careful here
            # the gradient of strided conv is a conv over a sparse image
            # which can be decomposed as a set of smaller convs
            dx = torch.empty_like(x)
            for offh in range(stride[0]):
                for offw in range(stride[1]):
                    poffh = (offh + padding[0]) % stride[0]
                    poffw = (offw + padding[1]) % stride[1]
                    pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
                    pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
                    if poffh >= r or poffw >= s:
                        dx[:, :, offh::stride[0], offw::stride[1]] = 0
                    else:
                        triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
                                          dypad[:, :, :, :],
                                          weight[:, :, poffh::stride[0], poffw::stride[1]],
                                          output=dx[:, :, offh::stride[0], offw::stride[1]],
                                          mask=acc_bitmask,
                                          values={'pad_h': pad_h,
                                                  'pad_w': pad_w})

        # gradient for the weight
        dw = None
        if ctx.needs_input_grad[1]:
            dw = torch.empty_like(weight)
            triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
                              x, dy, output=dw, mask=acc_bitmask)
            #print('dw: ', dw.view(-1)[0])
        return dx, dw, None, None, None, None, None, None

conv2d = _conv2d.apply


class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 acc_bitmask=None):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode)
        self.acc_bitmask = acc_bitmask

    def forward(self, input):
        #if self.kernel_size[0] == 3 and self.stride[0] != 1:
        #    print(self.padding, self.stride, input.size(), self.weight.size())
        #    return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return conv2d(input, self.weight, self.bias, self.stride,
                      self.padding, self.dilation, self.groups,
                      self.acc_bitmask)


def replace_conv2d(model, acc_bitmask=None):
    for child_name, child in model.named_children():
        if isinstance(child, nn.Conv2d):
            conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size,
                            child.stride, child.padding, child.dilation, child.groups,
                            child.bias, child.padding_mode,
                            acc_bitmask=acc_bitmask)
            for yparam, xparam in zip(conv2d.parameters(), child.parameters()):
                yparam.data.copy_(xparam.data)
            setattr(model, child_name, conv2d)
        else:
            replace_conv2d(child, acc_bitmask)

# initialize input
#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3
#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3])
#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3])
#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3])

if __name__ == '__main__':
    N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3
    #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3
    #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3
    pad, stride = 0, 1
    torch.manual_seed(0)
    x = torch.randn((N, C, H, W)).cuda()
    x.requires_grad_(True)
    #x.data[:] = 1
    # initialize layers
    torch.manual_seed(0)
    rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda()
    torch.manual_seed(0)
    tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda()
    #rconv2d.weight.data[:] = 1
    #tconv2d.weight.data[:] = 1
    ry = rconv2d(x)
    ty = tconv2d(x)
    # reference
    dy = torch.randn(ry.size()).cuda()
    #dy.data[:] = 1
    ry.backward(dy)
    rdx = x.grad.clone()
    rdw = rconv2d.weight.grad.clone()
    x.grad.zero_()
    # triton
    ty.backward(dy)
    tdx = x.grad.clone()
    tdw = tconv2d.weight.grad.clone()
    x.grad.zero_()
    # print error
    diff = lambda x, y: (x - y).abs().max()
    print(diff(ry, ty))
    print(diff(rdx, tdx))
    print(diff(rdw, tdw))
    #print((rdx - tdx).abs())

    #print((rdx[0,0,:,:] - tdx[0,0,:,:]))
    #print(rdx[0,0,:,:])
    #print(tdx[0,0,:,:])
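
The output-shape arithmetic above is the usual cross-correlation formula,
p = (h + 2*pad - r)//stride + 1. A minimal sketch (not part of the deleted
file) that checks it against torch.nn.functional.conv2d:

    import torch
    import torch.nn.functional as F

    def conv_out_size(h, w, r, s, padding, stride):
        # same formula as in _conv2d.forward above
        p = (h + 2*padding[0] - r)//stride[0] + 1
        q = (w + 2*padding[1] - s)//stride[1] + 1
        return p, q

    x = torch.randn(2, 3, 17, 19)
    w = torch.randn(4, 3, 3, 5)
    y = F.conv2d(x, w, stride=(2, 3), padding=(1, 2))
    assert y.shape[2:] == conv_out_size(17, 19, 3, 5, (1, 2), (2, 3))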
@@ -1,13 +0,0 @@
import torch
import triton


def linear(x, w, bias=None):
    m, k = x.size()
    n, _ = w.size()  # w is stored as (n, k), matching the 'nk' subscripts below
    out = torch.empty([m, n], device=x.device)
    triton.ops.einsum('mk,nk->mn', x, w, output=out)
    if bias is not None:
        out += bias
    return out
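
A usage sketch for linear (illustrative: it assumes w is stored as (n, k) to
match the 'nk' subscripts, that the einsum output= keyword from the sibling
conv2d file is available, and compares against PyTorch's own x @ w.T + bias):

    import torch
    import torch.nn.functional as F
    x = torch.randn(64, 128, device='cuda')
    w = torch.randn(256, 128, device='cuda')
    bias = torch.randn(256, device='cuda')
    ref = F.linear(x, w, bias)   # reference: x @ w.T + bias
    out = linear(x, w, bias)     # Triton einsum path
    print((out - ref).abs().max())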
@@ -1,2 +0,0 @@
from .einsum import _einsum, einsum
from .batchnorm import _batchnorm, batchnorm
@@ -1,136 +0,0 @@
import triton
import torch
import math

class _batchnorm(torch.autograd.Function):

    fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
                  float *X, float *G, float *B,
                  int N, float eps) {
  // pointers
  int c = get_program_id(1);
  int rm[TM] = 0 ... TM;
  float *px[TM] = X + rm + c*N;
  float* py[TM] = Y + rm + c*N;

  // compute mean
  float accm[TM] = 0;
  for(int i = 0; i < N; i = i + TM)
    accm = accm + *(px + i);
  float mean = (float)accm[+] / N;
  *(M + c) = mean;

  // compute variance
  float accv[TM] = 0;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    x = x - mean;
    accv = accv + x*x;
  }
  float var = (float)accv[+] / N;
  *(V + c) = var;

  // Normalize batch
  float gamma = *(G + c);
  float beta = *(B + c);
  float rstdg = 1 / sqrtf(var + eps) * gamma;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float y[TM] = (x - mean)*rstdg + beta;
    *(py + i) = y;
  }
}
"""

    bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
                  float *DY, float *X, float *G,
                  float *M, float *V,
                  int N, float epsilon) {

  // pointers
  int c = get_program_id(1);
  int rx[TM] = 0 ... TM;
  int offset = c*N;
  float* px[TM] = X + rx + offset;
  float* pdy[TM] = DY + rx + offset;
  float* pdx[TM] = DX + rx + offset;

  // fetch statistics
  float gamma = *(G + c);
  float mean = *(M + c);
  float var = *(V + c);
  float rstd = 1 / sqrtf(var + epsilon);

  // compute dgamma and dbeta
  float acc_dg[TM] = 0;
  float acc_db[TM] = 0;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float dy[TM] = *(pdy + i);
    acc_dg += dy*(x - mean)*rstd;
    acc_db += dy;
  }
  float dg = acc_dg[+];
  float db = acc_db[+];
  *(DG + c) = dg;
  *(DB + c) = db;

  // compute dx
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float dy[TM] = *(pdy + i);
    float xhat[TM] = (x - mean) * rstd;
    float xtmp[TM] = (xhat * dg + db) / N;
    float dx[TM] = (dy - xtmp) * rstd * gamma;
    *(pdx + i) = dx;
  }
}
"""

    fwd_kernel = None
    bwd_kernel = None

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        # lazy compilation of kernel
        if _batchnorm.fwd_kernel is None:
            _batchnorm.fwd_kernel = triton.kernel(_batchnorm.fwd_src, defines={'TM': 128})
        # shapes
        shape = triton.shape(x)
        dtype = x.dtype
        # allocate outputs
        C, H, W, B = shape[0], shape[1], shape[2], shape[3]
        y = triton.empty(shape, dtype=dtype)
        mean = triton.empty([C], dtype=dtype)
        var = triton.empty([C], dtype=dtype)
        # execute kernels
        _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                              grid=lambda opt: [1, C])
        # save
        ctx.save_for_backward(x, gamma, beta, mean, var)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        # lazy compilation of kernel
        if _batchnorm.bwd_kernel is None:
            # note: bwd_src tiles over TM, so TM (not TN) must be defined
            _batchnorm.bwd_kernel = triton.kernel(_batchnorm.bwd_src, defines={'TM': 128})
        # retrieve info
        x, gamma, beta, mean, var = ctx.saved_tensors
        eps = ctx.eps
        # allocate result
        dx = triton.empty(triton.shape(x), dtype=x.dtype)
        dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
        dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
        # execute
        C, H, W, B = triton.shape(x)
        _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
                              x, gamma, mean, var,
                              H*W*B, eps,
                              grid=lambda opt: [1, C])
        return dx, dgamma, dbeta, None

batchnorm = _batchnorm.apply
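
The forward kernel normalizes each channel over its N = H*W*B elements.
A reference sketch (illustrative, assuming a CUDA device) of the same
computation in plain PyTorch, which batchnorm above should match:

    import torch
    C, H, W, B = 32, 8, 8, 16
    x = torch.randn(C, H, W, B, device='cuda')
    gamma = torch.randn(C, device='cuda')
    beta = torch.randn(C, device='cuda')
    eps = 1e-5
    xc = x.view(C, -1)                      # one row per channel, N = H*W*B
    mean = xc.mean(dim=1, keepdim=True)
    var = xc.var(dim=1, unbiased=False, keepdim=True)
    ref = (xc - mean) / (var + eps).sqrt() * gamma[:, None] + beta[:, None]
    y = batchnorm(x, gamma, beta, eps)
    print((y.view(C, -1) - ref).abs().max())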
@@ -1,794 +0,0 @@
from math import ceil, log2
from enum import IntEnum
from functools import reduce
from operator import mul
from collections import OrderedDict
from collections import namedtuple
import re
import string
import triton
import torch
# numpy -- ideally removed in a future release
import numpy as np
# sympy -- ideally removed in a future release
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.printing.ccode import C89CodePrinter


class _einsum(torch.autograd.Function):


    #############################
    ## Triton-C code generation
    #############################

    def print_cc(expr, axes_0, axes_1, axes_2, prefix):
        if expr in axes_0:
            return f'{prefix}r{expr}[:, newaxis, newaxis]'
        if expr in axes_1:
            return f'{prefix}r{expr}[newaxis, :, newaxis]'
        if expr in axes_2:
            return f'{prefix}r{expr}[newaxis, newaxis, :]'
        return expr

    def unpack_cc(tile, axes, prefix, remat):
        ret = ''
        axes = list(map(str, axes))
        for i, d in enumerate(reversed(axes)):
            if i == len(axes) - 1:
                break
            currs = ''.join(axes[: len(axes) - i])
            nexts = ''.join(axes[: len(axes) - (i + 1)])
            ty = '' if remat else 'int '
            sz = '' if remat or tile is None else f'[{tile}]'
            ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n'
            ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n'
        return ret
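    # For illustration (hypothetical call): unpacking a fused (h, w) range of
    # size TM back into per-axis indices emits a div/mod chain, e.g.
    #
    #   >>> print(_einsum.unpack_cc('TM', ['h', 'w'], 'r', False))
    #    int rh[TM] = rhw / dim_w;
    #    int rw[TM] = rhw % dim_w;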

    def strides_cc(name, expr):
        ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1']
        ret = dict(zip(expr, ret))
        return ret

    def make_kernel(name, dtype,
                    expr_a, expr_b, expr_c,
                    sparse_a, sparse_b, sparse_c,
                    axes_m, axes_n, axes_k, axes_b,
                    multipleof_a, multipleof_b, multipleof_c,
                    stride_a_last, stride_b_last, stride_c_last,
                    lut_mode_a, lut_mode_b,
                    delta_a, delta_b,
                    blocks):

        use_lut_a = True
        use_lut_b = True

        outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k]
        outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k]
        outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k]
        outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k]
        outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k]

        src = ""

        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""

        src += f"""
__global__ void {name}(
  TYPE * A __noalias __readonly __aligned(16)
, TYPE * B __noalias __readonly __aligned(16)
, TYPE * C
, int * locks
, float alpha
, int matmul_m, int matmul_n, int matmul_k __multipleof(16)
, int div_m
"""
        for dim in [axes_m, axes_n, axes_k, axes_b]:
            for d in dim:
                src += f", int dim_{d}"
        src += "\n "
        for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c],
                                           ['a', 'b', 'c'],
                                           [multipleof_a, multipleof_b, multipleof_c],
                                           [sparse_a, sparse_b, sparse_c]):
            for d in range(len(dim) - 1):
                if sparse and dim[d] == sparse[0]:
                    src += f', int stride_{name}_block __multipleof({mult})'
                src += f", int stride_{name}_{d} __multipleof({mult})"
            src += "\n "
        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
            src += f", int rem_delta_a __multipleof({multipleof_a})"
        elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
            src += ", int* AD __noalias __readonly __aligned(16)"
        src += "\n "
        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_b_inner __multipleof({multipleof_b})"
            src += f", int rem_delta_b __multipleof({multipleof_b})"
        elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
            src += ", int* BD"
        src += "\n "
        if sparse_c:
            src += ", int* CD"
        if sparse_a or sparse_b:
            src += ", int width"
        src += """) {

    // program identifiers
    int pid_0 = get_program_id(0);
    int pid_1 = get_program_id(1);

"""
        if sparse_a:
            src += f"""
    int off_n = pid_0 / width;
    int off_header = pid_0 % width;
    int* header = AD + off_header * {2 + len(outer_sparse_a)};
    int* pdelta = AD + *(header + 0);
    matmul_k = *(header + 1);"""
            for i, d in enumerate(outer_sparse_a):
                src += f"""
    int off_{d} = *(header + {2 + i});"""
            src += f"""
    int inca = *(pdelta + 0);
    int incb = *(pdelta + 1);
    int off_{''.join(map(str, outer_dense_a))} = pid_1;
"""
            _einsum.unpack_cc(None, outer_dense_a, "off_", False)
        elif sparse_b:
            src += f"""
    int off_m = pid_0 / width;
    int off_header = pid_0 % width;
    int* header = BD + off_header * {2 + len(outer_sparse_b)};
    int* pdelta = BD + *(header + 0);
    matmul_k = *(header + 1);"""
            for i, d in enumerate(outer_sparse_b):
                src += f"""
    int off_{d} = *(header + {2 + i});"""
            src += f"""
    int incb = *(pdelta + 0);
    int inca = *(pdelta + 1);
    int off_{''.join(map(str, outer_dense_b))} = pid_1;
"""
            _einsum.unpack_cc(None, outer_dense_b, "off_", False)
        elif sparse_c:
            src += f"""
    // load LUT header
    int *header = CD + pid_0 * {len(sparse_c)};"""
            for i, d in enumerate(sparse_c):
                src += f"""
    int off_{d} = *(header + {i});"""
            src += f"""
    int off_{''.join(map(str, outer_dense_c))} = pid_1;"""
        else:
            src += """
    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
    int off_mn = pid_0 / div_m;
    int off_n = off_mn % grid_n;
    int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m);
    int off_b = get_program_id(1);"""

        src += """
#if TZ == 1
    int off_k = 0;
#else
    // get reduction sub-group program id
    int pid_z = get_program_id(2);
    int grid_z = get_num_programs(2);
    int div_z = matmul_k / TZ;
    int rem_z = matmul_k % TZ;
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
    int rem_k = matmul_k % TK;

    // create ranges
"""

        sparse = sparse_a + sparse_b + sparse_c
        for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k],
                                             ['TM', 'TN', 'TB', 'TK'],
                                             ['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'],
                                             [['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]):
            if not axes:
                continue
            currs = ''.join(map(str, axes))
            has_sparse_component = set(axes) & set(sparse)
            if has_sparse_component:
                src += f" int r{currs}[{tile}] = 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, f'r', False)
            else:
                src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, f'r', False)
            for pfx in prefixes:
                for d in axes:
                    is_dense_dim = d not in sparse
                    is_dense_storage = (pfx == 'a' and not sparse_a) or\
                                       (pfx == 'b' and not sparse_b) or\
                                       (pfx == 'c' and not sparse_c)
                    if not is_dense_dim and is_dense_storage:
                        src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n"
                    elif is_dense_dim and has_sparse_component:
                        src += f" int {pfx}r{d}[{tile}] = off_{d};\n"
                    else:
                        src += f" int {pfx}r{d}[{tile}] = r{d};\n"

        src += f"""
    // initialize pointers to A
    int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """
        for i, sym in enumerate(expr_a):
            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a')
            stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
            src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""

        if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK] = {cast}AD + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""

        src += f"""

    // initialize pointers to B
    int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}"""
        for i, sym in enumerate(expr_b):
            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b')
            stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
            src += f" + ({ccode}) * {stride}\n "
        src += ';'

        src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""

        if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to B look-up table
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK] = BD + offbdelta;"""

        rk = 'r{}'.format(''.join(map(str, axes_k)))
        src += f"""

    // prefetch
    int prefetch_k = select(rem_k > 0, rem_k, TK);
    bool checkam[TM] = ar""" + ''.join(map(str, axes_m)) + f""" < matmul_m;
    bool checkbn[TN] = br""" + ''.join(map(str, axes_n)) + f""" < matmul_n;
    bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k;
    bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""

        if sparse_a:
            src += f"""
    // update pointers to look-up tables
    pdelta += 2;
    int incda = *(pdelta + 0);
    int incdb = *(pdelta + 1);
    pa += incda;
    pb += incdb;"""
        if sparse_b:
            src += f"""
    // update pointers to look-up tables
    pdelta += 2;
    int incdb = *(pdelta + 0);
    int incda = *(pdelta + 1);
    pa += incda;
    pb += incdb;"""

        if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += """
    pa += rem_delta_a;"""
        elif not sparse_a and not sparse_b:
            src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""

        if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
    pb += rem_delta_b;"""
        elif not sparse_a and not sparse_b:
            src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""

        src += f"""

    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;

        // load inputs
        checkk = k > TK;
        checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
        checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis];
        a = *?(checka)pa;
        b = *?(checkb)pb;

        // update pointers"""
        if sparse_a:
            src += """
        pdelta += 2;
        incda = *(pdelta + 0);
        incdb = *(pdelta + 1);
        pa += incda;
        pb += incdb;
"""
        elif sparse_b:
            src += """
        pdelta += 2;
        incdb = *(pdelta + 0);
        incda = *(pdelta + 1);
        pa += incda;
        pb += incdb;
"""
        else:
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                src += """
        pa += stride_a_inner;"""
            else:
                src += """
        pa += incda;
        padelta += TK;
        incda = (*padelta)[newaxis, :, newaxis];"""
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                src += """
        pb += stride_b_inner;"""
            else:
                src += """
        pb += (*pbdelta)[:, newaxis, newaxis];
        pbdelta += TK;"""

        src += f"""
    }}
    TYPE c[TM, TN, TB] = acc;

    // initialize pointers to C
    int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}"""
        for i, sym in enumerate(expr_c):
            stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c')
            src += f"\n + ({ccode}) * {stride}"
        src += ';'

        src += """
    TYPE *pc[TM, TN, TB] = C + offc;

    // bounds-checking
    bool checkcm[TM] = cr""" + ''.join(map(str, axes_m)) + """ < matmul_m;
    bool checkcn[TN] = cr""" + ''.join(map(str, axes_n)) + """ < matmul_n;
    bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] &&
                              checkcn[newaxis, :, newaxis];

    // write back
#if TZ == 1
    *?(checkc)pc = c;
#else
    int *plock = locks + pid_mn + pid_b * get_num_programs(0);
    int *pcount = plock + 1024*1024;
    // spin
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc)pc = c;
    else
      *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % (grid_z));
    atomic_xchg(plock, 0);
#endif
}
"""
        # compilation options
        TM, TN, TB, TZ = [32], [32], 1, [1]
        TK = 16 if dtype == torch.float16 else 8
        defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
        for d, B in blocks.items():
            defines[f'BLOCK{d}'] = B
        # create kernel
        ret = triton.kernel(src, defines=defines)
        # set constant
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('AD', delta_a)
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('BD', delta_b)
        return ret
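    # For illustration: when TZ > 1, each of the grid_z reduction sub-groups
    # computes a partial tile of C, and the epilogue generated above
    # serializes their write-back with a software spin-lock. Schematically:
    #
    #   while atomic_cas(plock, 0, 1): pass    # acquire (spin)
    #   if count == 0: C_tile  = c             # first reducer overwrites
    #   else:          C_tile += c             # later reducers accumulate
    #   count = (count + 1) % grid_z           # atomic_xchg(pcount, ...)
    #   plock = 0                              # release: atomic_xchg(plock, 0)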

    ############################
    ## Look-up Table
    ############################

    class LUT_MODE(IntEnum):
        SCALAR = 1
        CONSTANT = 2
        DRAM = 3

    def lut_mode(delta):
        if delta.size == 0 or np.min(delta) == np.max(delta):
            return _einsum.LUT_MODE.SCALAR
        #if delta.size < 4096:
        #    return _einsum.LUT_MODE.CONSTANT
        return _einsum.LUT_MODE.DRAM
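    # For illustration (hypothetical values): a constant pointer increment
    # collapses to a single scalar stride, anything else currently goes
    # through a DRAM look-up table:
    #
    #   >>> _einsum.lut_mode(np.array([8, 8, 8]))    # LUT_MODE.SCALAR
    #   >>> _einsum.lut_mode(np.array([8, 8, 24]))   # LUT_MODE.DRAM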

    def symbolic_delta(symbols, axes):
        rank = len(symbols)
        strides = [sp.symbols(f'stride{d}') for d in range(rank)]
        nexts = {s: sp.symbols(f'next{s}') for s in axes}
        delta = 0
        for i in range(rank):
            delta += strides[i] * (symbols[i].subs(nexts) - symbols[i])
        return delta

    def unpack_offset(k, axes, dims):
        ret = dict()
        for d in reversed(axes):
            ret[d] = k % dims[d]
            k = k // dims[d]
        return ret
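    # For illustration (hypothetical call): unpack_offset is the numeric twin
    # of unpack_cc, splitting flat indices into coordinates, last axis fastest:
    #
    #   >>> off = _einsum.unpack_offset(np.arange(6), ['h', 'w'], {'h': 2, 'w': 3})
    #   >>> off['h']   # array([0, 0, 0, 1, 1, 1])
    #   >>> off['w']   # array([0, 1, 2, 0, 1, 2])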

    def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks):
        # depth of reductions
        depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes])
        # outer dimension indices
        outer = torch.nonzero(depth, as_tuple=False)
        outer = [outer[:, i] for i in range(outer.shape[1])]
        # find offset of outer dimensions
        depth = depth.view(-1)
        offsets = torch.zeros_like(depth)
        offsets[1:] = torch.cumsum(depth[:-1], 0)
        # compute delta for b
        # TODO: support multiple sparse red indices
        col = next((i for i, d in enumerate(sparse) if d in axes), None)
        block = blocks[sparse[-1].upper()]
        div = block // step
        delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous()
        delta_b *= block
        delta_b = [delta_b + step*i for i in range(div)]
        delta_b = torch.stack(delta_b, dim=1)
        delta_b = delta_b.view(-1)
        # compute delta for a
        bstride = 1
        for d in sparse[::-1]:
            if d in axes:
                break
            bstride *= blocks[d.upper()]
        order = [d for d in sparse if d not in axes] +\
                [d for d in sparse if d in axes]
        idx = [sparse.index(d) for d in order]
        layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device)
        layout = layout.permute(*idx)
        delta_a = layout[layout > 0] - 1
        delta_a *= np.prod(list(blocks.values()))
        saved = delta_a[offsets]
        delta_a[1:] = delta_a[1:] - delta_a[:-1]
        delta_a = delta_a.view(-1, 1).repeat(1, div)
        delta_a[:, 1:] = step*bstride
        delta_a[:, 0] -= (div - 1)*step*bstride
        delta_a[offsets, 0] = saved
        delta_a = delta_a.view(-1)
        delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous()
        # form look-up table
        depth *= blocks[symbols[-1].upper()]
        offsets *= div
        header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous()
        nouter = 2 + len(outer)
        header[::nouter] = header[::nouter]*2 + header.shape[0]
        lut = torch.cat((header, delta)).int().cpu().numpy()
        return lut, nouter, _einsum.LUT_MODE.DRAM

    def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut=None, nouter=None):
        # symbolic pointer increments
        symbols = [sp.symbols(x) for x in symbols]
        delta = _einsum.symbolic_delta(symbols, axes)
        args = [f'stride{d}' for d in range(len(stride))]
        args += [f'{sk}' for sk in axes]
        args += [f'next{sk}' for sk in axes]
        fn = sp.lambdify(args, delta, 'numpy')
        if lut is None:
            # inner axes values
            inner = [dims[d] for d in axes]
            inner = np.prod(inner)
            rem = inner % step
            rem = rem if rem > 0 else step
            # k = [0, 1, ..., step, rem, rem + 1, ... rem + inner]
            # nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step]
            k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
            nextk = np.concatenate((k[:step] + rem, k[step:] + step))
        else:
            idx = (lut[:lut[0]:nouter] - lut[0])//2
            k = lut[lut[0]+1::2]
            k = np.insert(k, idx, 0)
            nextk = k[1:]
            k = k[:-1]
        # offsets
        off = _einsum.unpack_offset(k, axes, dims)
        nextoff = _einsum.unpack_offset(nextk, axes, dims)
        # evaluate deltas
        args = [s for s in stride]
        args += [off[sk] for sk in axes]
        args += [nextoff[sk] for sk in axes]
        delta = fn(*args)
        delta = np.maximum(delta, 0)
        if lut is not None:
            idx = idx[1:] + np.arange(idx.shape[0] - 1)
            delta = np.delete(delta, idx)
            lut[lut[0]+1::2] = delta
            return None, None
        return delta, _einsum.lut_mode(delta[step:-step])
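    # For illustration (hypothetical call): for a dense operand indexed (k, n)
    # with strides (stride0, stride1), only the reduction symbol moves between
    # steps, so the symbolic pointer increment reduces to a single term:
    #
    #   >>> k, n = sp.symbols('k n')
    #   >>> _einsum.symbolic_delta([k, n], ['k'])
    #   stride0*(-k + nextk)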

    @staticmethod
    def make_sdd_lut(layout_c, sparse_c, blocks):
        nnz = torch.nonzero(layout_c, as_tuple=False)
        lut = nnz.reshape(-1).int().cuda()
        return lut

    ############################
    ## Compilation
    ############################

    class instance:

        locks = None
        kernel_cache = dict()

        def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks):
            # parse symbols
            expr_a, expr_bc = einsum.split(",")
            expr_b, expr_c = expr_bc.split("->")
            sym_a = expr_a.lower()
            sym_b = expr_b.lower()
            sym_c = expr_c.lower()
            sparse_a = [x.lower() for x in expr_a if x.isupper()]
            sparse_b = [x.lower() for x in expr_b if x.isupper()]
            sparse_c = [x.lower() for x in expr_c if x.isupper()]
            layout_a = layouts.get(expr_a)
            layout_b = layouts.get(expr_b)
            layout_c = layouts.get(expr_c)
            # parse axes
            axes_b = [d for d in sym_a if d in sym_b and d in sym_c]
            axes_m = [d for d in sym_a if d not in sym_b and d in sym_c]
            axes_k = [d for d in sym_a if d in sym_b and d not in sym_c]
            axes_n = [d for d in sym_b if d not in sym_a and d in sym_c]
            axes = axes_b + axes_m + axes_n + axes_k
            # check block sizes
            for d in sparse_a + sparse_b + sparse_c:
                if d.upper() not in blocks:
                    raise ValueError(f'unspecified block size for dimension: {d.upper()}')
            # check that a layout is present for each sparse operand
            if sparse_a and layout_a is None:
                raise ValueError('A is sparse but no layout was provided')
            if sparse_b and layout_b is None:
                raise ValueError('B is sparse but no layout was provided')
            if sparse_c and layout_c is None:
                raise ValueError('C is sparse but no layout was provided')
            # check dimensions
            dims_a = dict([(x, y) for x, y in zip(sym_a, shape_a) if x not in sparse_a])
            dims_b = dict([(x, y) for x, y in zip(sym_b, shape_b) if x not in sparse_b])
            dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape))
            dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape))
            # TODO: could be cleaner
            read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d]
            for d in axes_b + axes_m + axes_n + axes_k:
                dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None
                dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None
                if d in axes_b and dim_a and dim_b and dim_a != dim_b:
                    raise ValueError(f'incompatible batch dimension {d} (A: {dim_a}, B: {dim_b})')
                if d in axes_k and dim_a and dim_b and dim_a != dim_b:
                    raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})')
            dims = dict()
            dims.update(dims_a)
            dims.update(dims_b)
            for i, d in enumerate(sparse_a):
                dims[d] = layout_a.shape[i] * blocks[d.upper()]
            for i, d in enumerate(sparse_b):
                dims[d] = layout_b.shape[i] * blocks[d.upper()]
            # allocate output
            shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c]
            if sparse_c:
                shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum()))
            stride_c = [None] * len(shape_c)
            stride_c[-1] = 1
            for i in reversed(range(len(shape_c) - 1)):
                stride_c[i] = stride_c[i+1] * shape_c[i+1]
            # look-up tables
            TK = 16 if dtype == torch.float16 else 8
            if sparse_a and not sparse_b:
                delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks)
                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter)
            if sparse_b and not sparse_a:
                delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks)
                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter)
            if not sparse_a and not sparse_b:
                delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a)
                delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b)
            if sparse_c:
                delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks)
            # hash for recompilation
            stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
            stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
            stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
            stride_a_last = stride_a[-1]
            stride_b_last = stride_b[-1]
            stride_c_last = stride_c[-1]
            name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
                   f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
                   f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
            # recompile if necessary
            cache = _einsum.instance.kernel_cache
            if name not in cache:
                cachesize = len(cache)
                cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
                                                  dtype,
                                                  sym_a, sym_b, sym_c,
                                                  sparse_a, sparse_b, sparse_c,
                                                  axes_m, axes_n, axes_k, axes_b,
                                                  stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                  stride_a_last, stride_b_last, stride_c_last,
                                                  lut_mode_a, lut_mode_b,
                                                  delta_a, delta_b,
                                                  blocks)
            self.kernel = cache[name]
            # initialize locks
            if _einsum.instance.locks is None:
                _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
            # kernel arguments
            dim_m = [dims[d] for d in axes_m]
            dim_n = [dims[d] for d in axes_n]
            dim_k = [dims[d] for d in axes_k]
            dim_b = [dims[d] for d in axes_b]
            M = reduce(mul, dim_m, 1)
            N = reduce(mul, dim_n, 1)
            K = reduce(mul, dim_k, 1)
            B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1)
            stride_a = list(stride_a[:-1])
            stride_b = list(stride_b[:-1])
            stride_c = list(stride_c[:-1])
            alpha = 1.
            div_m = 1
            self.args = [None, None, None]
            self.args += [_einsum.instance.locks]
            self.args += [alpha, M, N, K, div_m]
            self.args += dim_m
            self.args += dim_n
            self.args += dim_k
            self.args += dim_b
            self.args += stride_a
            self.args += stride_b
            self.args += stride_c
            # LUT for A
            if lut_mode_a == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_a[TK], delta_a[0]]
            elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_a).cuda()]
            # LUT for B
            if lut_mode_b == _einsum.LUT_MODE.SCALAR:
                self.args += [delta_b[TK], delta_b[0]]
            elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM:
                self.args += [torch.from_numpy(delta_b).cuda()]
            # LUT for C
            if sparse_c:
                self.args += [delta_c]
            if sparse_a or sparse_b:
                width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter
                self.args += [width]
            # grid
            if sparse_a:
                self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')]
            elif sparse_b:
                self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')]
            elif sparse_c:
                width = int(layout_c.sum())
                self.grid = lambda opt: [width, B, opt.d('TZ')]
            else:
                self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
                                         triton.cdiv(N, opt.d('TN')),
                                         triton.cdiv(B, opt.d('TB')),
                                         opt.d('TZ')]
            # position of dynamic arguments
            self.pos_a = 0
            self.pos_b = 1
            self.pos_c = 2
            # save information on the operation
            self.expr_a = expr_a
            self.expr_b = expr_b
            self.expr_c = expr_c
            self.matmul_B = B
            self.matmul_M = M
            self.matmul_N = N
            self.matmul_K = K
            # output shape
            self.shape_c = shape_c

        def run(self, a, b):
            c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device)
            self.args[self.pos_a] = a
            self.args[self.pos_b] = b
            self.args[self.pos_c] = c
            self.kernel(*self.args, grid=self.grid)
            return c

    ############################
    ## Forward
    ############################

    instance_cache = dict()
    registry = dict()

    @staticmethod
    def forward(ctx, expr, a, b, layouts, blocks):
        # compile einsum instance
        cache = _einsum.instance_cache
        key = (expr, a.dtype,
               a.stride(), b.stride(),
               a.shape, b.shape)
        if key not in cache:
            cache[key] = _einsum.instance(expr, a.dtype,
                                          a.stride(), b.stride(),
                                          a.shape, b.shape,
                                          layouts, blocks)
        instance = cache[key]
        # run the kernel (the pre-allocated c is written in-place)
        c = instance.run(a, b)
        # save information in context
        ctx.expr_a = instance.expr_a
        ctx.expr_b = instance.expr_b
        ctx.expr_c = instance.expr_c
        ctx.matmul_B = instance.matmul_B
        ctx.matmul_M = instance.matmul_M
        ctx.matmul_N = instance.matmul_N
        ctx.matmul_K = instance.matmul_K
        ctx.save_for_backward(a, b)
        return c

    ############################
    ## Backward
    ############################

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        expr_a = ctx.expr_a
        expr_b = ctx.expr_b
        expr_c = ctx.expr_c
        # gradient of first argument
        da = None
        if ctx.needs_input_grad[1]:
            da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b)
        # gradient of second argument
        db = None
        if ctx.needs_input_grad[2]:
            db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy)
        # one gradient per forward input: (expr, a, b, layouts, blocks)
        return None, da, db, None, None


def einsum(expr, a, b, layouts=None, blocks=dict()):
    return _einsum.apply(expr, a, b, layouts or dict(), blocks)
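
A usage sketch (illustrative; assumes a CUDA device): a batched matmul
through this interface, with autograd exercising the two transposed
einsums in backward():

    import torch
    a = torch.randn(8, 128, 64, device='cuda', requires_grad=True)
    b = torch.randn(8, 64, 256, device='cuda', requires_grad=True)
    c = einsum('bmk,bkn->bmn', a, b)   # b: batch, m/n: outer, k: reduction
    c.backward(torch.randn_like(c))
    print(a.grad.shape, b.grad.shape)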
@@ -21,7 +21,7 @@ int main() {
   // config_t{ord, x[0], x[1], 1280, 1280, 1280},
   // config_t{ord, x[0], x[1], 1536, 1536, 1536},
   // config_t{ord, x[0], x[1], 2048, 2048, 2048},
-  config_t{ord, x[0], x[1], 4096, 4096, 4096},
+  config_t{ord, x[0], x[1], 8192, 8192, 8192},

   // config_t{ord, x[0], x[1], 256, 16, 256},
   // config_t{ord, x[0], x[1], 512, 16, 512},
@@ -102,7 +102,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
   stream->write(&*ddelta, true, 0, hdelta);

   // macros
-  rt::function::options_space_t opt;
+  rt::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});
   opt.defines.push_back({"TM", {"128"}});
   opt.defines.push_back({"TN", {"128"}});
@@ -125,7 +125,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
                    W*H*CI, W*H, W, 1,
                    CO*S*R , CO*S, CO, 1,
                    Q*P*CO, Q*P, Q, 1};
-  auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
+  auto grid = [Z,P,Q,CO](const rt::options_t& x) {
     return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
                       ceil(CO , x.D<int>("TN")),
                       (size_t)x.D<int>("TZ")};
@@ -107,7 +107,7 @@ void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vecto
   auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
   auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
   // create options
-  rt::function::options_space_t opt;
+  rt::options_space_t opt;


   // macros
@@ -101,46 +101,38 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
   // ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());

   // macros
-  rt::function::options_space_t opt;
+  rt::options_space_t opts;
   // A access patterns
-  opt.defines.push_back({"USEA", {AT? "a" : "a" }});
-  opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }});
-  opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }});
-  opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }});
-  opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
-  opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
+  opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
+  opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
   // B access patterns
-  opt.defines.push_back({"USEB", {BT? "b" : "b" }});
-  opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }});
-  opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }});
-  opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }});
-  opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
-  opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
+  opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
+  opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
   // data-type
-  opt.defines.push_back({"TYPE", {ty}});
+  opts.defines.push_back({"TYPE", {ty}});
   // tile sizes
   if(mode == TEST) {
-    opt.defines.push_back({"TM", {std::to_string(TM)}});
-    opt.defines.push_back({"TN", {std::to_string(TN)}});
-    opt.defines.push_back({"TK", {std::to_string(TK)}});
-    opt.defines.push_back({"TZ", {"1"}});
-    opt.num_warps = {nwarp};
+    opts.defines.push_back({"TM", {std::to_string(TM)}});
+    opts.defines.push_back({"TN", {std::to_string(TN)}});
+    opts.defines.push_back({"TK", {std::to_string(TK)}});
+    opts.defines.push_back({"TZ", {"1"}});
+    opts.num_warps = {nwarp};
   }
   if(mode == BENCH) {
-    opt.defines.push_back({"TM", {"64", "128"}});
-    opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"16"}});
-    opt.defines.push_back({"TZ", {"1"}});
-    opt.num_warps = {4};
+    opts.defines.push_back({"TM", {"128"}});
+    opts.defines.push_back({"TN", {"128"}});
+    opts.defines.push_back({"TK", {"32"}});
+    opts.defines.push_back({"TZ", {"1"}});
+    opts.num_warps = {4};
   }

   // kernels
-  rt::function function(src::dot, opt);
+  rt::function function(src::dot, opts);
   dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(),
                     1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()};

-  auto grid = [M, N](const rt::function::options_t& x) {
-    return rt::grid_t{ceil(M, x.D<int>("TM")),
+  auto grid = [M, N](const rt::options_t& x) {
+    return rt::grid_t{ceil(M, x.D<int>("TM"))*
                       ceil(N, x.D<int>("TN")),
                       (size_t)x.D<int>("TZ")};
   };
@@ -151,19 +143,25 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
   double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
   bench.push_back(tflops(triton_ns));

-  // // cublas
-  // if(cublas::cublasinit()){
-  //   T alpha(static_cast<double>(1));
-  //   T beta(static_cast<double>(0));
-  //   cublasGemmAlgo_t fastest;
-  //   cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-  //   double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
-  //                                                              &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-  //                                                              ldc, nullptr, fastest); }, stream);
-  //   bench.push_back(tflops(cublas_ms));
-  // }
+  // cublas
+  if(cublas::cublasinit()){
+    T alpha(static_cast<double>(1));
+    T beta(static_cast<double>(0));
+    cublasGemmAlgo_t fastest;
+    // cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+    double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, !AT, !BT, M, N, K,
+                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+                                                               ldc); }, stream);
+    bench.push_back(tflops(cublas_ms));
+  }
 }

+  // rt::options_t opt;
+  // for(auto &x: opts.defines)
+  //   opt.defines[x.first] = x.second[0];
+  // opt.num_warps = 1;
+  // std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;

   // test triton
   if(mode == TEST){
     srand(0);
@@ -95,25 +95,25 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
     y_strides.push_back(y_strides[d] + " * " + y_shapename[y_order[d]]);

   // options
-  rt::function::options_space_t opt;
-  opt.defines.push_back({"TYPE", {ty}});
+  rt::options_space_t opts;
+  opts.defines.push_back({"TYPE", {ty}});
   for(int d = 0; d < rank_x; d++)
-    opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
+    opts.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
+    opts.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
   if(TS.empty())
     TS = tile_nd(rank_x);
   for(int d = 0; d < rank_x; d++)
-    opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
+    opts.defines.push_back({"TS" + std::to_string(d), TS[d]});

   std::vector<size_t> axy;
   for(int d = 0; d < rank_x; d++)
     if(d != axis)
       axy.push_back(d);
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
+    opts.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
+    opts.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});

   std::string RED = "";
   for(int n = 0; n < rank_x; n++){
@@ -121,30 +121,40 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
       RED += ", ";
     RED += (n==axis) ? to_str(op) : ":";
   }
-  opt.defines.push_back({"RED", {RED}});
-  opt.num_warps = {1};
+  opts.defines.push_back({"RED", {RED}});
+  opts.num_warps = {2};

   // kernel
-  rt::function function(src::reduce_nd[rank_x - 1], opt);
+  rt::function function(src::reduce_nd[rank_x - 1], opts);

   // input buffers
   auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_x*dtsize));
   auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));

   // grid
-  reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
-  if(shape_x.size() > 1) args.S1 = shape_x[1];
-  if(shape_x.size() > 2) args.S2 = shape_x[2];
+  std::stringstream oss;
+  rt::add_arg(oss, *dx->cu());
+  rt::add_arg(oss, *dy->cu());
+  rt::add_arg(oss, (uint32_t)shape_x[0]);
+  if(shape_x.size() > 1) rt::add_arg(oss, (uint32_t)shape_x[1]);
+  if(shape_x.size() > 2) rt::add_arg(oss, (uint32_t)shape_x[2]);
   std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
   auto grid = grid_nd(shape_x, ts);

   // metrics
   if(mode == BENCH){
     auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
-    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
+    double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
     bench.push_back(gbps(triton_ns));
   }

+  // rt::options_t opt;
+  // for(auto &x: opts.defines)
+  //   opt.defines[x.first] = x.second[0];
+  // opt.num_warps = 1;
+  // std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;

   // test triton
   if(mode == TEST){
     std::vector<NumericT> hy(size_y);
@@ -153,7 +163,7 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
     init_zeros(hy);
     init_rand(hx);
     stream->write(&*dx, true, 0, hx);
-    function((void**)&args, sizeof(args), grid, stream, device);
+    function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
     stream->synchronize();
     stream->read(&*dy, true, 0, hy);
     cc_reduce_nd(ry, hx, op, axis, shape_x);
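
The hunks above replace a fixed reduce_arg_t struct with stream-serialized
arguments (rt::add_arg). The host-side idea -- sketched here in Python with
hypothetical values, since rt::add_arg just appends bytes -- is to pack each
kernel argument into one contiguous buffer in declaration order:

    import struct
    args = bytearray()
    args += struct.pack('Q', 0x7f0000001000)   # device pointer for dx
    args += struct.pack('Q', 0x7f0000002000)   # device pointer for dy
    args += struct.pack('I', 1024)             # (uint32_t)shape_x[0]
    print(len(args))   # 20 bytes handed to the launcher, like oss.str()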
@@ -2,6 +2,9 @@ namespace src {
|
|||||||
|
|
||||||
const char *dot =
|
const char *dot =
|
||||||
R"(
|
R"(
|
||||||
|
#define STM 8
|
||||||
|
#define STN 8
|
||||||
|
|
||||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||||
TYPE * B __noalias __readonly __aligned(16),
|
TYPE * B __noalias __readonly __aligned(16),
|
||||||
TYPE * C __noalias __aligned(16),
|
TYPE * C __noalias __aligned(16),
|
||||||
@@ -14,54 +17,59 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     int ldc __multipleof(8),
                     int* locks) {
   // prologue
-  int ridx = get_program_id(0);
-  int ridy = get_program_id(1);
-  int ridz = get_program_id(2);
-  int gridx = M / TM;
-  int gridy = N / TN;
-  int rid = ridx + ridy * gridx;
-  ridx = rid / gridy;
-  ridy = rid % gridy;
-  int rm[TM] = ridx * TM + 0 ... TM;
-  int rn[TN] = ridy * TN + 0 ... TN;
+  int pid = get_program_id(0);
+  int pidz = get_program_id(2);
+  int gridm = (M + TM - 1) / TM;
+  int gridn = (N + TN - 1) / TN;
+  int width = STM*gridn;
+  int stm = pid / width;
+  int RSTM = min(gridm - stm*STM, STM);
+  int stn = (pid % width) / (RSTM*STN);
+  int RSTN = min(gridn - stn*STN, STN);
+  int laneid = pid % (RSTM * RSTN);
+  int lanem = laneid / RSTN;
+  int lanen = laneid % RSTN;
+  int pidm = stm*STM + lanem;
+  int pidn = stn*STN + lanen;
+  int rm[TM] = pidm * TM + 0 ... TM;
+  int rn[TN] = pidn * TN + 0 ... TN;
 
   // reduction splitting
   K = K / TZ;
-  int rk[TK] = ridz * K + 0 ... TK;
+  int rk[TK] = pidz * K + 0 ... TK;
 
   // pointers to operands
-  int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
-  int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
-  TYPE* pa[SHAPE_A] = A + offa;
-  TYPE* pb[SHAPE_B] = B + offb;
+  int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
+  int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
+  TYPE* pa[TM, TK] = A + offa;
+  TYPE* pb[TK, TN] = B + offb;
 
   // prefetches operands
-  bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
-  bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
-  TYPE a[SHAPE_A] = checka ? *pa : 0;
-  TYPE b[SHAPE_B] = checkb ? *pb : 0;
+  bool checka[TM, TK] = rk[newaxis, :] < K;
+  bool checkb[TK, TN] = rk[:, newaxis] < K;
+  TYPE a[TM, TK] = checka ? *pa : 0;
+  TYPE b[TK, TN] = checkb ? *pb : 0;
+  pa += TK * STRIDE_AK;
+  pb += TK * STRIDE_BK;
   // reduction loop
   float acc[TM, TN] = 0;
   for(int k = K; k > 0; k -= TK){
-    acc += USEA @ USEB;
-    bool checka[SHAPE_A] = k > TK;
-    bool checkb[SHAPE_B] = k > TK;
-    pa += TK * STRIDE_AK;
-    pb += TK * STRIDE_BK;
+    bool checka[TM, TK] = k > TK;
+    bool checkb[TK, TN] = k > TK;
+    acc += a @ b;
     a = *?(checka)pa;
     b = *?(checkb)pb;
+    pa += TK * STRIDE_AK;
+    pb += TK * STRIDE_BK;
   }
   acc = acc * alpha;
   TYPE c[TM, TN] = acc;
 
   // epilogue
-  int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
-  int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
-  int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
+  int rcm[TM] = pidm * TM + 0 ... TM;
+  int rcn[TN] = pidn * TN + 0 ... TN;
+  int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
   TYPE* pc[TM, TN] = C + offc;
-  bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
+  bool checkc[TM, TN] = rcm[:, newaxis] < M &&
+                        rcn[newaxis, :] < N;
 #if (TZ==1)
   *?(checkc) pc = c;
 #else
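The rewritten prologue replaces the old row-major CTA numbering with grouped "super-tile" scheduling: program IDs first fill an STM x STN block of tiles before moving on, so concurrently running CTAs tend to reuse the same rows of A and columns of B in cache. Below is a host-side replica of the kernel's index arithmetic, handy for checking the pid -> (pidm, pidn) mapping on paper; the grid sizes in main() are illustrative, not from the tests:

    #include <algorithm>
    #include <cstdio>

    void swizzle(int pid, int gridm, int gridn, int STM, int STN,
                 int* pidm, int* pidn) {
      int width  = STM * gridn;                       // pids per super-tile row
      int stm    = pid / width;                       // super-tile row
      int RSTM   = std::min(gridm - stm * STM, STM);  // tile rows left in this super-tile
      int stn    = (pid % width) / (RSTM * STN);      // super-tile column
      int RSTN   = std::min(gridn - stn * STN, STN);  // tile columns left
      int laneid = pid % (RSTM * RSTN);               // position inside the super-tile
      *pidm = stm * STM + laneid / RSTN;
      *pidn = stn * STN + laneid % RSTN;
    }

    int main() {
      // with a 16x16 grid of tiles and STM = STN = 8, the first 64 program
      // ids all land in the top-left 8x8 super-tile, so their A rows and
      // B columns overlap heavily
      for(int pid = 0; pid < 4; pid++) {
        int pidm, pidn;
        swizzle(pid, 16, 16, 8, 8, &pidm, &pidn);
        printf("pid=%d -> (%d, %d)\n", pid, pidm, pidn);
      }
    }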
@@ -19,13 +19,13 @@ inline size_t ceil(size_t x, size_t y) {
 }
 
 inline rt::function::grid_fn_ty grid1d(size_t N) {
-  return [N](const rt::function::options_t& x) {
+  return [N](const rt::options_t& x) {
     return rt::grid_t{ceil(N, x.D<int>("TN"))};
   };
 }
 
 inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
-  return [M, N](const rt::function::options_t& x) {
+  return [M, N](const rt::options_t& x) {
     return rt::grid_t{ceil(M, x.D<int>("TM")),
                       ceil(N, x.D<int>("TN"))};
   };
@@ -33,7 +33,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 
 inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
                                         const std::vector<std::string>& ts) {
-  return [&shape, &ts](const rt::function::options_t& x) {
+  return [&shape, &ts](const rt::options_t& x) {
     rt::grid_t ret;
     for(size_t d = 0; d < shape.size(); d++)
       ret.push_back(ceil(shape[d], x.D<int>(ts[d])));
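These helpers recompute the launch grid from whatever tile sizes the autotuner selected, via ceiling division of each tensor dimension by its tile. One caveat visible in the diff: `grid_nd` captures `shape` and `ts` by reference, so both vectors must outlive the returned lambda. A standalone sketch of the same pattern, where `options` is a stub standing in for `rt::options_t`:

    #include <cstdio>
    #include <vector>

    // stub standing in for rt::options_t; always picks a 128-wide tile
    struct options { int D(const char*) const { return 128; } };

    inline size_t ceil_div(size_t x, size_t y) { return (x + y - 1) / y; }

    int main() {
      std::vector<int> shape = {1000, 500};
      // like grid_nd: one grid dimension per axis, rounded up to whole tiles;
      // capturing `shape` by reference means it must outlive this lambda
      auto grid = [&shape](const options& x) {
        std::vector<size_t> ret;
        for(size_t d = 0; d < shape.size(); d++)
          ret.push_back(ceil_div(shape[d], x.D("T")));
        return ret;
      };
      for(size_t g : grid(options{}))
        printf("%zu\n", g);  // prints 8, then 4
    }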
@@ -71,6 +71,12 @@ void init_zeros(std::vector<T>& x) {
     x[i] = 0;
 }
 
+template<class T>
+void init_ones(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 1;
+}
+
 /* ------------------------
  * Loop Nests
  * ------------------------ */
@@ -163,7 +169,7 @@ for(size_t i = 0; i < hc.size(); i++)
       std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
       return false;
     }
   return true;
 }
 
 }
@@ -8,22 +8,58 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to test
-  typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
+  typedef std::tuple<dtype_t, bool, bool, int, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(int TM: std::vector<int>{32, 64, 128})
-  for(int TN: std::vector<int>{32, 64, 128})
-  for(int TK: std::vector<int>{16})
-  for(int nwarps: std::vector<int>{4})
-  for(bool AT: std::array<bool, 2>{false, true})
-  for(bool BT: std::array<bool, 2>{false, true}){
-    configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
-  }
+  for(dtype_t dtype: std::vector<dtype_t>{FLOAT, HALF})
+  for(bool AT: std::vector<bool>{false, true})
+  for(bool BT: std::vector<bool>{false, true}){
+    // 1 warp
+    configs.push_back({dtype, AT, BT, 16, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 32, 1});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 16, 64, 16, 1});
+      configs.push_back({dtype, AT, BT, 16, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 64, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 16, 64, 64, 1});
+    }
+    // 2 warps
+    configs.push_back({dtype, AT, BT, 64, 32, 64, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 64, 2});
+    configs.push_back({dtype, AT, BT, 64, 32, 16, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 16, 2});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 2});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 2});
+    // 4 warps
+    configs.push_back({dtype, AT, BT, 128, 64, 16, 4});
+    configs.push_back({dtype, AT, BT, 64, 128, 16, 4});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 4});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 4});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 128, 32, 64, 4});
+      configs.push_back({dtype, AT, BT, 32, 128, 64, 4});
+    }
+    // 8 warps
+    configs.push_back({dtype, AT, BT, 128, 256, 16, 8});
+    configs.push_back({dtype, AT, BT, 256, 128, 16, 8});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+    }
+
+  };
   // test
   dtype_t dtype;
   bool AT, BT;
   int M, N, K, TM, TN, TK, nwarp;
   for(const auto& c: configs){
-    std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
+    std::tie(dtype, AT, BT, TM, TN, TK, nwarp) = c;
+    M = TM;
+    N = TN;
+    K = TK;
     std::cout << "Testing " << c << " ... " << std::flush;
     if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
       std::cout << " Pass! " << std::endl;
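The test no longer sweeps a TM x TN x TK cross-product at fixed shapes; each config now pins M = TM, N = TN and K = TK, so every launch runs exactly one program instance and exercises a single tile configuration per warp count. A quick arithmetic check, using values from one of the 4-warp configs above:

    #include <cstdio>

    inline int ceil_div(int x, int y) { return (x + y - 1) / y; }

    int main() {
      int TM = 128, TN = 64, TK = 16;   // one of the 4-warp configs
      int M = TM, N = TN, K = TK;       // as set in the loop body above
      // with M == TM and N == TN the 2-D launch grid collapses to one CTA,
      // and the reduction loop runs a single K-iteration
      printf("grid = {%d, %d}, K-iterations = %d\n",
             ceil_div(M, TM), ceil_div(N, TN), K / TK);  // grid = {1, 1}, 1
    }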
@@ -20,12 +20,15 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
   std::vector<config_t> configs = {
-    config_t{{8, 8, 4}, 2, ADD},
-    config_t{{32}, 0, ADD},
+    config_t{{64}, 0, ADD},
+    config_t{{128}, 0, MIN},
     config_t{{32, 32}, 0, MAX},
     config_t{{32, 32}, 1, ADD},
     config_t{{32, 64}, 0, ADD},
-    config_t{{64, 32}, 1, ADD}
+    config_t{{64, 32}, 1, ADD},
+    config_t{{8, 8, 4}, 2, ADD},
+    config_t{{8, 8, 4}, 0, ADD},
+    config_t{{8, 8, 4}, 1, ADD}
   };
   // does the work
   int axis;
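For reference, each tuple here is (input shape, reduction axis, op); the reference check `cc_reduce_nd` reduces the input along that axis. A small illustration of the expected output shape, assuming the usual convention that reducing drops the reduced dimension:

    #include <cstdio>
    #include <vector>

    // reducing along `axis` drops that dimension from the shape
    std::vector<int> reduced_shape(std::vector<int> shape, int axis) {
      shape.erase(shape.begin() + axis);
      return shape;
    }

    int main() {
      // e.g. config_t{{8, 8, 4}, 2, ADD}: reduce an 8x8x4 tensor along axis 2
      for(int d : reduced_shape({8, 8, 4}, 2))
        printf("%d ", d);   // prints: 8 8
      printf("\n");
    }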