Compare commits

20 commits:

2d6df9b518
1b842f8e5e
d3e584d4ba
d35014ba47
5ce1b726dc
858dec8372
90ded16c32
abbc554838
9b32075062
c2e6b90ff1
bfacc191b3
f5ad168686
c3c0ff0552
9e9d781912
d5f20dbce0
d4baad426d
5123db0b7d
12b6158c5c
b352b16567
d132b7442b
.github/workflows/integration-tests.yml (7 changes, vendored)
@@ -5,6 +5,7 @@ on:
   pull_request:
     branches:
       - master
+      - v2.0

 jobs:
@@ -18,12 +19,16 @@ jobs:
     - name: Checkout
      uses: actions/checkout@v2

+    - name: Clear cache
+      run: |
+        rm -r /tmp/triton/
+      continue-on-error: true
+
    - name: Install Triton
      run: |
        alias python='python3'
        cd python
        pip3 install -e .
-        rm -r /tmp/triton/

    - name: Unit tests
      run: |
@@ -26,7 +26,6 @@ Version 1.1 is out! New features include:
 - Automatic on-disk caching of compiled binary objects
 - Random Number Generation
 - Faster (up to 2x on A100), cleaner blocksparse ops
 - Fixed the semantics of comparison with NaN to match that of Python

 # Contributing

@@ -45,7 +45,7 @@ You can then test your installation by running the unit tests:
 .. code-block:: bash

     pip install -r requirements-test.txt
-    pytest -vs .
+    pytest -vs test/unit/

 and the benchmarks
@@ -27,7 +27,8 @@ private:
   void update_graph_trans(ir::instruction *i);
   void update_graph_broadcast(ir::instruction *i);
   void update_graph_dot(ir::instruction *i);
-  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
+  void update_graph_elementwise(ir::instruction *i,
+                                bool is_masked_load_async=false);
   void update_graph_no_edge(ir::instruction *i);
   void update_graph(ir::instruction *i);

@@ -33,7 +33,7 @@ namespace codegen{
 std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
                                                      codegen::target* target,
                                                      int sm, int num_warps,
-                                                     int num_stages, bool force_nc_cache, int &shared_static);
+                                                     int num_stages, int &shared_static);

 }
@@ -122,8 +122,7 @@ public:
             analysis::allocation *alloc,
             analysis::swizzle *swizzle,
             target *tgt,
-            unsigned num_warps,
-            bool force_nc_cache = false);
+            unsigned num_warps);

   void visit_value(ir::value* v);
   void visit_phi_node(ir::phi_node*);
@@ -148,12 +147,14 @@ public:
   void visit_store_inst(ir::store_inst*);
   void visit_unmasked_store_inst(ir::unmasked_store_inst*);
   void visit_masked_store_inst(ir::masked_store_inst*);
+  void visit_cat_inst(ir::cat_inst*);
   void visit_reshape_inst(ir::reshape_inst*);
   void visit_splat_inst(ir::splat_inst*);
   void visit_broadcast_inst(ir::broadcast_inst*);
   void visit_downcast_inst(ir::downcast_inst*);
   void visit_exp_inst(ir::exp_inst*);
   void visit_cos_inst(ir::cos_inst*);
+  void visit_umulhi_inst(ir::umulhi_inst* x);
   void visit_sin_inst(ir::sin_inst*);
   void visit_log_inst(ir::log_inst*);
   void visit_get_program_id_inst(ir::get_program_id_inst*);
@@ -213,7 +214,6 @@ private:
   std::set<ir::value*> seen_;

   unsigned num_warps_;
-  bool force_nc_cache_;

   std::map<analysis::data_layout*, Value*> offset_a_m_;
   std::map<analysis::data_layout*, Value*> offset_a_k_;
@@ -130,13 +130,14 @@ public:
   value *create_xor(value *lhs, value *rhs);
   value *create_or(value *lhs, value *rhs);
   // Input/Output
-  value *create_load(value *arg);
+  value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
   value *create_store(value *ptr, value *val);
-  value *create_masked_load(value *arg, value *mask, value *false_value);
+  value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
   value *create_masked_store(value *ptr, value *val, value *mask);
   // Block instruction
   value *create_splat(value *arg, const type::block_shapes_t &shapes);
   value *create_reshape(value *arg, const type::block_shapes_t &shapes);
+  value *create_cat(value *lhs, value *rhs);
   value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
   // Built-in instruction
   value *create_get_program_id(unsigned axis);
@@ -153,8 +154,10 @@ public:
   value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
   value *create_select(value *pred, value *if_value, value *else_value);
   // Intrinsics
+  // These have no place in the IR, and hopefully they can be removed at some point
+  value *create_umulhi(value* lhs, value* rhs);
   value *create_copy_to_shared(value *arg);
-  value *create_masked_load_async(value *arg, value *mask, value *false_value);
+  value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
   value *create_copy_from_shared(value *arg);
   value *create_barrier(const std::string &name = "");
   value *create_async_wait(int N);
@@ -61,13 +61,14 @@ struct dispatch{

   // casting ops
   static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
+  static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
   static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
   static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
   static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
   static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);

   // memory operators
-  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
+  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
   static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
   static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
   static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
@@ -90,6 +91,7 @@ struct dispatch{
   static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);

   // math
+  static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
   static ir::value *exp(ir::value *x, ir::builder *builder);
   static ir::value *log(ir::value *x, ir::builder *builder);
   static ir::value *cos(ir::value *x, ir::builder *builder);
@@ -132,6 +132,7 @@ enum value_id_t: unsigned {
   // retile
   INST_RESHAPE,
   INST_SPLAT,
+  INST_CAT,
   INST_BROADCAST,
   INST_DOWNCAST,
   // builtin
@@ -142,6 +143,7 @@ enum value_id_t: unsigned {
   INST_ATOMIC_EXCH,
   INST_ATOMIC_RMW,
   // math
+  INST_UMULHI,
   INST_EXP,
   INST_COS,
   INST_SIN,
@@ -394,22 +394,38 @@ public:

 // load
 class load_inst: public io_inst {
+public:
+  enum CACHE_MODIFIER : uint32_t {
+    NONE=0,
+    CA,
+    CG,
+  };
+
+  CACHE_MODIFIER get_cache_modifier() const { return cache_; }
 protected:
-  load_inst(value *ptr, value_id_t id, unsigned num_ops,
+  load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
             const std::string &name = "", instruction *next = nullptr);
+  std::string get_cache_modifier_repr() const {
+    if (cache_ == CA) return ".ca";
+    if (cache_ == CG) return ".cg";
+    return "";
+  }
+  CACHE_MODIFIER cache_;

 private:
   static type *get_pointee_type(type *ty);

 };

 // unmasked load
 class unmasked_load_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "unmasked_load"; }
-  unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
+  unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);

 public:
   static unmasked_load_inst* create(value *ptr,
+                                    CACHE_MODIFIER cache,
                                     const std::string &name = "",
                                     instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -419,8 +435,8 @@ public:
 // masked load
 class masked_load_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load"; }
-  masked_load_inst(value *ptr, value *mask, value *false_value,
+  std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
+  masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                    const std::string &name, instruction *next);

 public:
@@ -429,6 +445,7 @@ public:
   value *get_false_value_operand() { return get_operand(2); }
   // factory method
   static masked_load_inst* create(value *ptr, value *mask, value *false_value,
+                                  CACHE_MODIFIER cache,
                                   const std::string &name = "",
                                   instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(masked_load_inst)
@@ -438,8 +455,8 @@ public:
 // masked load async
 class masked_load_async_inst: public load_inst {
 private:
-  std::string repr_impl() const { return "masked_load_async_async"; }
-  masked_load_async_inst(value *ptr, value *mask, value *false_value,
+  std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
+  masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                          const std::string &name, instruction *next);

 public:
@@ -448,6 +465,7 @@ public:
   value *get_false_value_operand() { return get_operand(2); }
   // factory method
   static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
+                                        load_inst::CACHE_MODIFIER cache,
                                         const std::string &name = "",
                                         instruction *next = nullptr);
   _TRITON_DEFINE_CLONE(masked_load_async_inst)
@@ -502,6 +520,21 @@ public:
 // retile_inst classes
 //===----------------------------------------------------------------------===//

+// cat
+
+class cat_inst: public instruction {
+private:
+  std::string repr_impl() const { return "cat"; }
+  cat_inst(value *x, value *y, const std::string &name, instruction *next);
+
+public:
+  static instruction* create(value *lhs, value *rhs,
+                             const std::string &name = "",
+                             instruction *next = nullptr);
+  _TRITON_DEFINE_CLONE(cat_inst)
+  _TRITON_DEFINE_ACCEPT(cat_inst)
+};
+
 // retile

 class retile_inst: public unary_inst {
@@ -636,6 +669,17 @@ public:
   static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
 };

+class umulhi_inst: public builtin_inst {
+private:
+  umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
+  std::string repr_impl() const { return "umulhi"; }
+  _TRITON_DEFINE_CLONE(umulhi_inst)
+  _TRITON_DEFINE_ACCEPT(umulhi_inst)
+
+public:
+  static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
+};
+
 class exp_inst: public builtin_inst {
 private:
   exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
@@ -785,6 +829,7 @@ public:
 // intrinsics classes
 //===----------------------------------------------------------------------===//

+
 class copy_to_shared_inst: public unary_inst{
 private:
   using unary_inst::unary_inst;
@@ -866,35 +911,6 @@ public:
                            instruction *next=nullptr);
 };

-//// On NVIDIA, implementation is such that
-//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
-//// so as to enable re-association on nv_static_program_idx which is constant
-//class make_range_dyn: public instruction {
-//private:
-//  make_range_dyn(type *ty, const std::string &name, instruction *next);
-//  std::string repr_impl() const { return "nv_dynamic_program_idx"; }
-//  _TRITON_DEFINE_CLONE(make_range_dyn)
-//  _TRITON_DEFINE_ACCEPT(make_range_dyn)

-//public:
-//  static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
-//};

-//class make_range_sta: public constant {
-//private:
-//  make_range_sta(make_range *range);

-//public:
-//  static make_range_sta *get(make_range* range);
-//  make_range* get_range() const;
-//  std::string repr() const { return "nv_static_program_idx"; }
-//  _TRITON_DEFINE_ACCEPT(make_range_sta)

-//private:
-//  make_range *range_;
-//};


 /* constant range */
 class make_range: public instruction{
   make_range(type *ty, constant_int* first, constant_int* last);
@@ -45,9 +45,11 @@ class masked_store_inst;
 class retile_inst;
 class reshape_inst;
 class splat_inst;
+class cat_inst;
 class broadcast_inst;
 class downcast_inst;

+class umulhi_inst;
 class exp_inst;
 class cos_inst;
 class sin_inst;
@@ -122,6 +124,7 @@ public:
   virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
   virtual void visit_masked_store_inst(masked_store_inst*) = 0;

+  virtual void visit_umulhi_inst(umulhi_inst*) = 0;
   virtual void visit_exp_inst(exp_inst*) = 0;
   virtual void visit_cos_inst(cos_inst*) = 0;
   virtual void visit_sin_inst(sin_inst*) = 0;
@@ -129,6 +132,7 @@ public:

   virtual void visit_reshape_inst(reshape_inst*) = 0;
   virtual void visit_splat_inst(splat_inst*) = 0;
+  virtual void visit_cat_inst(cat_inst*) = 0;
   virtual void visit_broadcast_inst(broadcast_inst*) = 0;
   virtual void visit_downcast_inst(downcast_inst*) = 0;

@@ -150,13 +154,10 @@ public:
   virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
   virtual void visit_barrier_inst(barrier_inst*) = 0;
   virtual void visit_async_wait_inst(async_wait_inst*) = 0;
-  // virtual void visit_make_range_dyn(make_range_dyn*) = 0;
   virtual void visit_make_range(make_range*) = 0;
   virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;

   virtual void visit_function(function*) = 0;

-  // virtual void visit_make_range_sta(make_range_sta*) = 0;
   virtual void visit_undef_value(undef_value*) = 0;
   virtual void visit_constant_int(constant_int*) = 0;
   virtual void visit_constant_fp(constant_fp*) = 0;
@@ -79,19 +79,28 @@ void axes::update_graph_dot(ir::instruction *i) {
     graph_.add_edge({dot, d}, {D, d});
 }

-void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
+void axes::update_graph_elementwise(ir::instruction *i,
+                                    bool is_masked_load_async) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
   if(!op->get_type()->is_block_ty())
     return;
   auto rank = op->get_type()->get_tile_rank();
-  for(unsigned d = 0; d < rank; d++)
-    for(ir::value* opx: i->ops())
-      for(ir::value* opy: i->ops()){
-        if(connect_ret && !i->get_type()->is_void_ty())
-          graph_.add_edge({i, d}, {opx, d});
-        graph_.add_edge({opx, d}, {opy, d});
+  for(unsigned d = 0; d < rank; d++) {
+    // If we are dealing with a masked async load we need to attach the
+    // dimensions so we match the behaviour of the copy_to_shared instruction
+    // which async masked load replaces.
+    if (is_masked_load_async) {
+      graph_.add_edge({i, d}, {i, d});
+    }
+
+    for(ir::value* opx: i->ops())
+      for(ir::value* opy: i->ops()) {
+        if(!is_masked_load_async && !i->get_type()->is_void_ty())
+          graph_.add_edge({i, d}, {opx, d});
+        graph_.add_edge({opx, d}, {opy, d});
       }
+  }
 }

@@ -107,12 +116,13 @@ void axes::update_graph(ir::instruction *i) {
   switch (i->get_id()) {
     case ir::INST_REDUCE: return update_graph_reduce(i);
     case ir::INST_RESHAPE: return update_graph_reshape(i);
-    case ir::INST_SPLAT: return update_graph_no_edge(i);;
+    case ir::INST_SPLAT: return update_graph_no_edge(i);
+    case ir::INST_CAT: return update_graph_elementwise(i, true);
     case ir::INST_TRANS: return update_graph_trans(i);
     case ir::INST_BROADCAST: return update_graph_broadcast(i);
     case ir::INST_DOT: return update_graph_dot(i);
     case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
-    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, false);
+    case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
     case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
     case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
     default: return update_graph_elementwise(i);
@@ -198,21 +198,24 @@ scanline_layout::scanline_layout(size_t num_warps,
   bool is_dot = std::any_of(values.begin(), values.end(),
                             [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

-  ir::value *ptr = nullptr;
+  std::vector<ir::value*> ptrs;
   for(ir::value *v: values)
-    for(ir::user *usr: v->get_users())
-      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
-        if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
-          ptr = io->get_pointer_operand();
-      }
+    for(ir::user *usr: v->get_users())
+      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
+        if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
+          ptrs.push_back(io->get_pointer_operand());
+      }

   unsigned i = order_[0];
   int contiguous = 1;
-  if(ptr){
+  for(ir::value* ptr: ptrs){
     int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
-    contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
+    contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
   }

   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
   mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
   size /= shape_[i];
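The `128 / nbits` term above caps per-thread contiguity at one 128-bit vectorized access, and taking the max over all pointer operands (instead of keeping only the last candidate) makes the bound robust when a value feeds several loads or stores. A small worked illustration of that expression (a standalone sketch, not library code):

```python
def contiguity_bound(alignment: int, nbits: int) -> int:
    # at most 128 bits, i.e. one vectorized memory access, per thread
    return min(alignment, 128 // nbits)

assert contiguity_bound(alignment=16, nbits=16) == 8  # fp16: up to 8 elements/thread
assert contiguity_bound(alignment=16, nbits=32) == 4  # fp32: up to 4 elements/thread
```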
@@ -496,6 +499,7 @@ void layouts::run(ir::module &mod) {
     make_graph(i);
   });

+
   // connected components
   graph_.connected_components(&values_, &groups_);

@@ -537,6 +541,7 @@ void layouts::run(ir::module &mod) {
       tmp_[atom] = id;
     }
   });
+
 }

 }
@@ -25,7 +25,7 @@ namespace codegen {
 // TODO:
 // There should be a proper pass manager there!
 std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
-                                                     int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) {
+                                                     int cc, int num_warps, int num_stages, int& shared_static) {
   // generate llvm code
   std::string name = ir.get_function_list()[0]->get_name();
   std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
@@ -46,7 +46,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target);
   codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
-  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache);
+  codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
   // run passes
   dce.run(ir);
   peephole.run(ir);
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
     case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
     case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
     case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
-    case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
+    case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
     case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
     case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
     case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
                      analysis::allocation *alloc,
                      analysis::swizzle *swizzle,
                      target *tgt,
-                     unsigned num_warps, bool force_nc_cache)
+                     unsigned num_warps)
   : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
-    tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
+    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {

 }
@@ -629,10 +629,9 @@ void generator::visit_load_inst(ir::load_inst* x){
       // -----
       std::ostringstream asm_oss;
       asm_oss << "@$" << n_words; // predicate
-//      if(force_nc_cache_)
-        asm_oss << " ld.global";
-//      else
-//        asm_oss << " ld.global.cg";
+      asm_oss << " ld.global";
+      if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
+      if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
       if(n_words > 1)
         asm_oss << ".v" << n_words; // vector width
       asm_oss << ".b" << width; // word size
@@ -775,6 +774,22 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
   visit_store_inst(x);
 }

+/**
+ * \brief Code Generation for `cat`
+ */
+void generator::visit_cat_inst(ir::cat_inst* x) {
+  auto idxs = idxs_.at(x);
+  ir::value* lhs = x->get_operand(0);
+  ir::value* rhs = x->get_operand(1);
+  int i = 0;
+  for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
+    vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
+  for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
+    vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
+  }
+}
+
+
 /**
  * \brief Code Generation for `reshape`
@@ -862,6 +877,20 @@ void generator::visit_cos_inst(ir::cos_inst* x){
   }
 }

+/**
+ * \brief Code Generation for `umulhi`
+ */
+void generator::visit_umulhi_inst(ir::umulhi_inst* x){
+  std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
+  FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
+  InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
+  for(auto idx: idxs_.at(x)){
+    Value* lhs = vals_[x->get_operand(0)][idx];
+    Value* rhs = vals_[x->get_operand(1)][idx];
+    vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
+  }
+}
+
 /**
  * \brief Code Generation for `sin`
  */
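The inline PTX above, `mul.hi.u32`, yields the upper 32 bits of the full 64-bit product of two unsigned 32-bit operands, which is the semantics the new `umulhi` instruction exposes. A plain-Python model of that semantics (an illustration of the instruction's meaning, not Triton API):

```python
def umulhi32(a: int, b: int) -> int:
    # upper 32 bits of the exact 64-bit product, as mul.hi.u32 computes
    return ((a & 0xffffffff) * (b & 0xffffffff)) >> 32

assert umulhi32(0xffffffff, 0xffffffff) == 0xfffffffe
assert umulhi32(1, 1) == 0  # small products have an all-zero high half
```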
@@ -2197,7 +2226,8 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) {

 void generator::visit_make_range(ir::make_range* x) {
   for(indices_t idx: idxs_.at(x)){
-    vals_[x][idx] = idx[0];
+    Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
+    vals_[x][idx] = add(start, idx[0]);
   }
 }
@@ -11,6 +11,8 @@ namespace transform{

 ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
                                std::set<ir::value*>& seen) {
+  if (dynamic_cast<ir::phi_node*>(root))
+    return root;
   if(!seen.insert(root).second)
     return root;
   if(!root->get_type()->is_block_ty())
@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
   int nts = layout->nts(layout->get_order()[0]);
   int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
   if(nts*dtsize >= 4){
-    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
     copy_to_shared->replace_all_uses_with(new_load);
     return true;
   }
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
   builder.set_insert_point(select);
   ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
                                                    if_value->get_mask_operand(),
-                                                   select->get_else_value_op());
+                                                   select->get_else_value_op(),
+                                                   if_value->get_cache_modifier());
   select->replace_all_uses_with(new_load);
   return true;
 }
@@ -101,21 +101,33 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir:
   }
 }

+struct pipeline_info_t {
+  ir::load_inst* load;
+  ir::phi_node* ptr;
+  ir::dot_inst* dot;
+
+  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
+    : load(load), ptr(ptr), dot(dot) {}
+};
+
 void pipeline::run(ir::module &mod) {
   if (num_stages_ <= 1)
     return;
   // *Very* conservative heuristics for pre-fetching.
   // A load instruction can be pipelined if:
   //   - the pointer is a phi node that references a value
   //     in its basic block (i.e., pointer induction variable)
   //   - the load has only a single use in a dot instruction
   // As more use cases become apparent, this pass will be improved
-  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  std::vector<pipeline_info_t> to_pipeline;
   ir::for_each_instruction(mod, [&](ir::instruction *i){
     if(auto* load = dynamic_cast<ir::load_inst*>(i)){
       ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
       auto users = load->get_users();
+      auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
       if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
-         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
-        to_pipeline.push_back({load, ptr});
+         && users.size() == 1 && dot)
+        to_pipeline.push_back({load, ptr, dot});
   }});
   // do the pipelining
   std::vector<ir::phi_node*> new_loads;
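The heuristic in the comment above matches the canonical Triton reduction loop: the load pointers become phi nodes (pointer induction variables) and each loaded block has a single use inside `tl.dot`. A hedged sketch of a loop shape this pass can pipeline (an illustrative kernel written for this note, not taken from the compare):

```python
import triton
import triton.language as tl

@triton.jit
def _dot_loop(a_ptr, b_ptr, c_ptr, **meta):
    BLOCK, K = meta['BLOCK'], meta['K']
    rm = tl.arange(0, BLOCK)
    rn = tl.arange(0, BLOCK)
    rk = tl.arange(0, BLOCK)
    pa = a_ptr + rm[:, None] * K + rk[None, :]      # becomes a pointer phi
    pb = b_ptr + rk[:, None] * BLOCK + rn[None, :]  # becomes a pointer phi
    acc = tl.zeros((BLOCK, BLOCK), tl.float32)
    for k in range(0, K, BLOCK):
        a = tl.load(pa)       # single use of each load ...
        b = tl.load(pb)
        acc += tl.dot(a, b)   # ... in a dot instruction
        pa += BLOCK           # pointer induction variable
        pb += BLOCK * BLOCK
    pc = c_ptr + rm[:, None] * BLOCK + rn[None, :]
    tl.store(pc, acc)
```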
@@ -123,8 +135,8 @@ void pipeline::run(ir::module &mod) {
   const int num_stages = num_stages_;
   std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
   for(auto info: to_pipeline){
-    ir::load_inst* load = info.first;
-    ir::phi_node* ptr = info.second;
+    ir::load_inst* load = info.load;
+    ir::phi_node* ptr = info.ptr;
     ir::basic_block* block = load->get_parent();
     ir::basic_block* header = block->get_predecessors()[0];
     auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
@@ -166,7 +178,7 @@ void pipeline::run(ir::module &mod) {
       false_value = remat_false_value;
     } else
       false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
-    first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
+    first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());

     for (int stage = 1; stage < num_stages-1; ++stage) {
       // mask is the loop condition of the previous iteration
@@ -181,7 +193,7 @@ void pipeline::run(ir::module &mod) {
         first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
         false_value = remat_false_value;
       }
-      first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
+      first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
     }

     // create new phis for induction variables
@@ -210,7 +222,7 @@ void pipeline::run(ir::module &mod) {
       next_mask = builder.create_and(next_mask, remat_mask);
       false_value = remat_false_value;
     }
-    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());


     // phi node
@@ -245,7 +257,7 @@ void pipeline::run(ir::module &mod) {
     }
     else
       false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
-    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
+    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
     // pre-fetch next iteration
     builder.set_insert_point(block->get_inst_list().back());
     ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -256,7 +268,7 @@ void pipeline::run(ir::module &mod) {
       next_mask = builder.create_and(next_mask, remat_mask);
      false_value = remat_false_value;
     }
-    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
     // phi node
     builder.set_insert_point(block->get_first_non_phi());
     ir::phi_node* new_load = builder.create_phi(ty, 2);
@@ -288,22 +300,23 @@ void pipeline::run(ir::module &mod) {
     std::vector<ir::instruction*> insts;
     ir::load_inst* dst;
   };
-  std::map<ir::basic_block*, move_config_t> to_move;
+  std::vector<move_config_t> to_move(to_pipeline.size());

   if(has_copy_async_){
-    for(ir::function* fn: mod.get_function_list())
-    for(ir::basic_block* bb: fn->blocks())
-    for(ir::instruction* inst: bb->get_inst_list()){
-      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
-        recursive_deps(i, bb, to_move[bb].insts);
-      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
-        to_move[bb].dst = i;
+    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
+      auto info = to_pipeline[idx];
+      ir::load_inst* load = info.load;
+      ir::phi_node* ptr = info.ptr;
+      ir::dot_inst* dot = info.dot;
+      ir::basic_block* bb = dot->get_parent();
+      recursive_deps(dot, bb, to_move[idx].insts);
+      to_move[idx].dst = load;
     }

-    for(auto& x: to_move){
-      builder.set_insert_point_after(x.second.dst);
-      for(ir::instruction* i: x.second.insts){
-        x.first->erase(i);
+    for(auto& move_config: to_move){
+      builder.set_insert_point_after(move_config.dst);
+      for(ir::instruction* i: move_config.insts){
+        i->get_parent()->erase(i);
         builder.insert(i);
       }
     }
@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
   ofs.close();
   std::string cmd;
   int err;
-  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+  cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
   err = system(cmd.c_str());
   CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );
@@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
 // load/store instructions
 //===----------------------------------------------------------------------===//

-value *builder::create_load(value *ptr){
-  return insert(unmasked_load_inst::create(ptr));
+value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){
+  return insert(unmasked_load_inst::create(ptr, cache));
 }

 value *builder::create_store(value *ptr, value *val){
   return insert(unmasked_store_inst::create(ptr, val));
 }

-value *builder::create_masked_load(value *ptr, value *mask, value *false_value){
-  return insert(masked_load_inst::create(ptr, mask, false_value));
+value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){
+  return insert(masked_load_inst::create(ptr, mask, false_value, cache));
 }

 value *builder::create_masked_store(value *ptr, value *val, value *mask){
@@ -297,6 +297,10 @@ value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
   return insert(reshape_inst::create(arg, shapes));
 }

+value *builder::create_cat(value *lhs, value *rhs) {
+  return insert(cat_inst::create(lhs, rhs));
+}
+
 value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
   return insert(splat_inst::create(arg, shapes));
 }
@@ -369,6 +373,9 @@ value *builder::create_select(value *pred, value *if_value, value *else_value){
 // intrinsic instructions
 //===----------------------------------------------------------------------===//

+value *builder::create_umulhi(value *lhs, value *rhs) {
+  return insert(umulhi_inst::create(lhs, rhs));
+}

 value *builder::create_copy_to_shared(value *arg) {
   return insert(copy_to_shared_inst::create(arg));
@@ -379,8 +386,8 @@ value *builder::create_copy_from_shared(value *arg) {
   return insert(copy_from_shared_inst::create(arg));
 }

-value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) {
-  return insert(masked_load_async_inst::create(ptr, mask, false_value));
+value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
+  return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
 }

 value *builder::create_barrier(const std::string &name) {
@@ -368,6 +368,10 @@ ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *b
   return builder->create_reshape(input, dst_shape);
 }

+ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
+  return builder->create_cat(lhs, rhs);
+}
+
 ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
   if (!input->get_type()->is_block_ty())
     return builder->create_splat(input, shape);
@@ -497,7 +501,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
 // Memory Operators
 //===----------------------------------------------------------------------===//

-ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) {
+ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
   if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
     throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
   if(ptr->get_type()->is_block_ty()){
@@ -517,8 +521,17 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
     ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
     ptr = dispatch::cast(ptr, ptr_ty, builder);
   }
+  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
+  if (!cache_modifier.empty()) {
+    if (cache_modifier == ".ca")
+      cache = load_inst::CA;
+    else if (cache_modifier == ".cg")
+      cache = load_inst::CG;
+    else
+      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
+  }
   if (!mask && !other)
-    return builder->create_load(ptr);
+    return builder->create_load(ptr, cache);
   if (!mask)
     throw std::runtime_error("`other` cannot be provided without `mask`");
   auto shape = ptr->get_type()->get_block_shapes();
@@ -527,7 +540,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
     if(ptr->get_type()->is_block_ty())
       other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
   }
-  return builder->create_masked_load(ptr, mask, other);
+  return builder->create_masked_load(ptr, mask, other, cache);
 }

 ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
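End to end, the `cache_modifier` string accepted here selects the `.ca`/`.cg` qualifier that `visit_load_inst` appends to the generated `ld.global` (see the codegen hunk above), and `test_load_cache_modifier` later in this compare asserts exactly that on the emitted PTX. A minimal usage sketch mirroring that test:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy(dst, src, **meta):
    offsets = tl.arange(0, 128)
    # ".cg" emits ld.global.cg (cache in L2 only); ".ca" emits ld.global.ca
    # (cache at all levels); the default empty string keeps plain ld.global.
    x = tl.load(src + offsets, cache_modifier='.cg')
    tl.store(dst + offsets, x)

src = torch.randn(128, device='cuda')
dst = torch.empty(128, device='cuda')
_copy[(1,)](dst, src)
```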
@@ -706,6 +719,11 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build
 // Math
 //===----------------------------------------------------------------------===//

+ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
+  binary_op_type_checking(x, y, builder);
+  return builder->insert(umulhi_inst::create(x, y));
+}
+
 ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
   return builder->create_exp(x);
 }
@@ -433,8 +433,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
 { }

 // load_inst
-load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
-  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
+load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
+  : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache)
 { }

 // load
@@ -447,41 +447,44 @@ type *load_inst::get_pointee_type(type *ty) {
 }

 // unmasked_load
-unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
-  : load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
+unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
+  : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) {
   set_operand(0, ptr);
 }

-unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
-  return new unmasked_load_inst(ptr, name, next);
+unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) {
+  return new unmasked_load_inst(ptr, cache, name, next);
 }

 // masked load
-masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
+masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
                                    const std::string &name, instruction *next)
-  : load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
+  : load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) {
   set_operand(0, ptr);
   set_operand(1, mask);
   set_operand(2, false_value);
 }

 masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
+                                           load_inst::CACHE_MODIFIER cache,
                                            const std::string &name, instruction *next) {
-  return new masked_load_inst(ptr, mask, false_value, name, next);
+  return new masked_load_inst(ptr, mask, false_value, cache, name, next);
 }

 // masked load async
 masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                                               load_inst::CACHE_MODIFIER cache,
                                                const std::string &name, instruction *next)
-  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
+  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) {
   set_operand(0, ptr);
   set_operand(1, mask);
   set_operand(2, false_value);
 }

 masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
+                                                       load_inst::CACHE_MODIFIER cache,
                                                        const std::string &name, instruction *next) {
-  return new masked_load_async_inst(ptr, mask, false_value, name, next);
+  return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
 }

 // store
@@ -519,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
 // retile_inst classes
 //===----------------------------------------------------------------------===//

+// cat
+
+cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
+  : instruction(block_type::get(x->get_type()->get_scalar_ty(),
+                                {x->get_type()->get_block_shapes()[0] +
+                                 y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
+  set_operand(0, x);
+  set_operand(1, y);
+}
+
+instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
+  return new cat_inst(lhs, rhs, name, next);
+}
+
 // retile

 retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
                          const std::string &name, instruction *next)
   : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }


 // reshape

 instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
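Per the shape computation above, `cat` concatenates two blocks along their leading axis, so the result length is `x.shape[0] + y.shape[0]`, with the left operand's elements placed first (see `visit_cat_inst` in the generator). Assuming the new `cat` binding is surfaced at the Python level as `tl.cat`, a usage sketch:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _cat_kernel(out, **meta):
    x = tl.zeros((64,), tl.float32)      # 64 zeros
    y = tl.zeros((64,), tl.float32) + 1  # 64 ones
    z = tl.cat(x, y)                     # shape (128,): zeros, then ones
    tl.store(out + tl.arange(0, 128), z)

out = torch.empty(128, device='cuda')
_cat_kernel[(1,)](out)
```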
@@ -758,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
 }


+// umulhi
+
+umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
+  : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
+  set_operand(0, lhs);
+  set_operand(1, rhs);
+}
+
+instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
+  return new umulhi_inst(lhs, rhs, name, next);
+}
+
+
 // exp

 exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
@@ -874,8 +907,8 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last)
 make_range *make_range::create(constant_int *first, constant_int *last) {
   assert(first->get_type()->is_integer_ty());
   assert(first->get_type() == last->get_type());
-  assert(((constant_int*)first)->get_value() == 0);
-  type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value()});
+  // assert(((constant_int*)first)->get_value() == 0);
+  type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
   return new make_range(ty, first, last);
 }
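With the zero-start assertion relaxed, the block shape becomes `last - first` and `visit_make_range` (above) adds `first` to every index, so `tl.arange(start, end)` now produces `start, start+1, ..., end-1` for any constant bounds instead of requiring `start == 0`; `test_arange` later in this compare exercises starts of 0, 1, 7 and 16.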
@@ -126,7 +126,7 @@ void SlotTracker::create_function_slot(const value *v) {
 }

 int SlotTracker::get_local_slot(const value *v) {
-  assert(dynamic_cast<constant>(v) == nullptr && "Can't get a constant slot");
+  assert(dynamic_cast<const constant*>(v) == nullptr && "Can't get a constant slot");

   // Check for uninitialized state and do lazy initialization.
   initialize_if_needed();
@@ -121,7 +121,7 @@ class CMakeBuild(build_ext):

 setup(
     name="triton",
-    version="1.1.0",
+    version="1.1.2",
     author="Philippe Tillet",
     author_email="phil@openai.com",
     description="A language and compiler for custom Deep Learning operations",
@@ -314,7 +314,7 @@ std::string zeros_docstr = R"pbdoc(

     :param shape: Shape of the new array, e.g., (8, 16) or (8, )
     :type shape: tuple of ints
-    :param dtype: Data-type of the new array, e.g., triton.float16
+    :param dtype: Data-type of the new array, e.g., tl.float16
     :type dtype: triton.ir.dtype
 )pbdoc";
 ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) {
@@ -673,4 +673,4 @@ ir::value *subscript(ir::value *self, std::vector<py::object> slices, ir::builde
     }
   }
   return builder->create_reshape(self, shape);
 }
 }
@@ -203,7 +203,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
 // CUDA
 std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
                                                         uint64_t device, int num_warps, int num_stages,
-                                                        bool force_nc_cache, asm_map_t &asm_map){
+                                                        asm_map_t &asm_map){
   llvm::LLVMContext ctx;
   // device properties
   CUdevice dev = (CUdevice)device;
@@ -215,7 +215,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::nvidia_cu_target target(cc);
   int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
@@ -236,12 +236,12 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
 // HIP
 std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
                                                          uint64_t device, int num_warps, int num_stages,
-                                                         bool force_nc_cache, asm_map_t &asm_map){
+                                                         asm_map_t &asm_map){
   llvm::LLVMContext ctx;
   // Triton-IR -> NVPTX LLVM-IR
   triton::codegen::amd_cl_target target;
   int n_shared_bytes;
-  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
+  auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
   std::string tmp;
   llvm::raw_string_ostream llir(tmp);
   llir << *llvm;
@@ -255,7 +255,7 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name

 void init_triton_codegen(py::module &&m) {
   m.def(
-    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
+    "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
       std::string name = ir.get_function_list()[0]->get_name();
       // record asm as we generate
       asm_map_t asm_map;
@@ -264,9 +264,9 @@ void init_triton_codegen(py::module &&m) {
       asm_map["ttir"] = py::cast(ttir.str());
       llvm::LLVMContext ctx;
       if(backend == CUDA)
-        return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+        return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
       if(backend == ROCM)
-        return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
+        return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
     }, py::return_value_policy::take_ownership);
   m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
     if(backend == CUDA)
@@ -313,6 +313,7 @@ void init_triton_frontend(py::module &&m) {
   m.def("arange", &ir::dispatch::arange, ret::reference);
   m.def("zeros", &ir::dispatch::zeros, ret::reference);
   // type manipuatation
+  m.def("cat", &ir::dispatch::cat, ret::reference);
   m.def("reshape", &ir::dispatch::reshape, ret::reference);
   typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
   typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
@@ -340,6 +341,7 @@ void init_triton_frontend(py::module &&m) {
   m.def("max", &ir::dispatch::max, ret::reference);
   m.def("sum", &ir::dispatch::sum, ret::reference);
   // math
+  m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
   m.def("exp", &ir::dispatch::exp, ret::reference);
   m.def("log", &ir::dispatch::log, ret::reference);
   m.def("cos", &ir::dispatch::cos, ret::reference);
@@ -476,6 +478,7 @@ void init_triton_ir(py::module &&m) {
   // constants
   .def("get_int1", &ir::builder::get_int1, ret::reference)
   .def("get_int32", &ir::builder::get_int32, ret::reference)
+  .def("get_int64", &ir::builder::get_int64, ret::reference)
   .def("get_float16", &ir::builder::get_float16, ret::reference)
   .def("get_float32", &ir::builder::get_float32, ret::reference)
   .def("get_range", &ir::builder::get_range, ret::reference);
@@ -515,10 +515,113 @@ def test_dot(epilogue, device='cuda'):
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx

+def test_dot_without_load():
+    @triton.jit
+    def kernel(out, **meta):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32,32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
+
 # ---------------
 # test arange
 # ---------------

+@pytest.mark.parametrize("start", [0, 1, 7, 16])
+def test_arange(start, device='cuda'):
+    BLOCK = 128
+    z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device)
+    @triton.jit
+    def _kernel(z, **meta):
+        off = tl.arange(0, meta['BLOCK'])
+        val = tl.arange(meta['START'], meta['END'])
+        tl.store(z + off, val)
+    _kernel[(1,)](z_tri, START=start, END=start+BLOCK, BLOCK=BLOCK)
+    z_ref = torch.arange(start, BLOCK+start, dtype=torch.int32, device=device)
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+
 # ---------------
 # test load
 # ---------------
 # 'bfloat16': torch.bfloat16,
+# Testing masked loads with an intermate copy to shared memory run.
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_masked_load_shared_memory(dtype, device='cuda'):
+    M = 32
+    N = 32
+    K = 8
+
+    in1 = torch.rand((M, K), dtype=dtype, device=device)
+    in2 = torch.rand((K, N), dtype=dtype, device=device)
+    out = torch.zeros((M, N), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in1_ptr, in2_ptr, output_ptr,
+                in_stride, in2_stride, out_stride,
+                in_numel, in2_numel, out_numel, **meta):
+        M = meta['M']
+        N = meta['N']
+        K = meta['K']
+
+        M_offsets = tl.arange(0, M)
+        N_offsets = tl.arange(0, N)
+        K_offsets = tl.arange(0, K)
+
+        in_offsets = M_offsets[:, None] * in_stride + K_offsets[None,:]
+        in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None,:]
+
+        # Load inputs.
+        x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
+        w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
+
+        # Without a dot product the memory doesn't get promoted to shared.
+        o = tl.dot(x, w)
+
+        # Store output
+        output_offsets = M_offsets[:, None] * out_stride + N_offsets[None,:]
+        tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
+
+    pgm = _kernel[(1,)](in1, in2, out,
+                        in1.stride()[0],
+                        in2.stride()[0],
+                        out.stride()[0],
+                        in1.numel(),
+                        in2.numel(),
+                        out.numel(),
+                        M=M, N=N, K=K)
+
+    reference_out =torch.matmul(in1, in2)
+    triton.testing.allclose(out, reference_out)
+
+@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
+def test_load_cache_modifier(cache):
+    src = torch.empty(128, device='cuda')
+    dst = torch.empty(128, device='cuda')
+
+    @triton.jit
+    def _kernel(dst, src, **meta):
+        offsets = tl.arange(0, 128)
+        x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
+        tl.store(dst+offsets, x)
+
+    pgm = _kernel[(1,)](dst, src, CACHE=cache)
+    ptx = pgm.asm['ptx']
+
+    if cache == '':
+        assert 'ld.global.ca' not in ptx
+        assert 'ld.global.cg' not in ptx
+    if cache == '.cg':
+        assert 'ld.global.cg' in ptx
+        assert 'ld.global.ca' not in ptx
+    if cache == '.ca':
+        assert 'ld.global.ca' in ptx
+        assert 'ld.global.cg' not in ptx
+
 # ---------------
 # test store
@@ -112,7 +112,7 @@ BLOCK = 1024
 # test generation of random uint32
 @pytest.mark.parametrize('size, seed',
                          [(size, seed) for size in ['10', '4,53', '10000']\
-                                       for seed in [0, 42, 124, 54]]
+                                       for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]
                          )
 def test_randint(size, seed, device='cuda'):
     size = list(map(int, size.split(',')))
@@ -132,34 +132,6 @@ def test_randint(size, seed, device='cuda'):
     out_ref = [gen.random_raw()[0] for _ in out_tri]
     assert out_tri == out_ref

-# test conversion of random uint32 into random float in [0, 1]
-def test_uint32_to_uniform_float():
-    @triton.jit
-    def kernel(SRC, TGT, N, **meta):
-        pid = tl.program_id(0)
-        offset = pid * BLOCK + tl.arange(0, BLOCK)
-        src = tl.load(SRC + offset)
-        tgt = tl.random.uint32_to_uniform_float(src)
-        tl.store(TGT + offset, tgt, mask=offset < N)
-
-    def run(source):
-        target = -torch.ones(source.shape, dtype=torch.float32, device=source.device)
-        N = source.numel()
-        grid = lambda meta: (triton.cdiv(N, BLOCK),)
-        kernel[grid](source, target, N)
-        return target
-
-    # check range of edge values
-    n = 100
-    source = torch.tensor(list(range(n)) + list(range(-n, 0)), dtype=torch.int32).cuda()
-    target = run(source).tolist()
-    assert target == sorted(target)
-    assert all(0.0 <= num < 1.0 for num in target)
-    # check distribution is uniform
-    source = torch.randint(-2**31, 2**31 - 1, dtype=torch.int32, size=(100000,)).cuda()
-    target = run(source).tolist()
-    assert scipy.stats.kstest(target, 'uniform', args=(0, 1)).statistic < 0.01
-
 # test uniform PRNG
 @pytest.mark.parametrize('size, seed',
                          [(size, seed) for size in [1000000]\
@@ -1,5 +1,5 @@
 # version
-__version__ = '1.0.1'
+__version__ = '1.1.1'

 # TODO: torch needs to be imported first
 # or pybind11 shows `munmap_chunk(): invalid pointer`
@@ -103,7 +103,8 @@ class CodeGenerator(ast.NodeVisitor):
         arg_values = []
         for i, arg_name in enumerate(arg_names):
             if i in self.constants:
-                arg_values.append(self.constants[i])
+                cst = triton.language.core._to_ir(self.constants[i], self.builder)
+                arg_values.append(cst)
             else:
                 if i in self.attributes:
                     is_ptr = fn.args[i].type.is_ptr()
@@ -463,9 +464,6 @@ class Kernel:
    @staticmethod
    def _type_name(obj):
        type_names = {
            int: 'I',
            float: 'f',
            bool: 'B',
            triton.language.float8: 'f8',
            torch.bfloat16: 'bf16',
            torch.float16: 'f16',
@@ -477,12 +475,25 @@ class Kernel:
            torch.int32: 'i32',
            torch.int64: 'i64',
        }
        return type_names[obj]
        if hasattr(obj, 'data_ptr'):
            return type_names[obj.dtype]
        if isinstance(obj, int):
            if abs(obj) <= 0xffffffff:
                return 'I'
            return 'L'
        if isinstance(obj, float):
            return 'f'
        if isinstance(obj, bool):
            return 'B'
        assert False


    @staticmethod
    def _to_triton_ir(context, obj):
        type_map = {
            'I': _triton.ir.type.get_int32,
            'L': _triton.ir.type.get_int64,
            'f': _triton.ir.type.get_fp32,
            'B': _triton.ir.type.get_int1,
            'f8': _triton.ir.type.get_fp8,
@@ -498,11 +509,11 @@ class Kernel:
        }
        # convert torch.Tensor to Triton IR pointers
        if hasattr(obj, 'data_ptr'):
            name = Kernel._type_name(obj.dtype)
            name = Kernel._type_name(obj)
            elt_ty = type_map[name](context)
            return _triton.ir.type.make_ptr(elt_ty, 1)
        # default path returns triton.ir.type directly
        name = Kernel._type_name(obj.__class__)
        name = Kernel._type_name(obj)
        return type_map[name](context)

    @staticmethod
@@ -511,7 +522,7 @@ class Kernel:
        types_key = [None] * len(wargs)
        for i, arg in enumerate(wargs):
            prefix = 'P' if i in tensor_idxs else ''
            suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
            suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
            types_key[i] = prefix + suffix
        return tuple(types_key)

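With this hunk, `_type_name` is called on the argument's value rather than its class, so the specialization key distinguishes 32-bit from 64-bit integer scalars. A minimal standalone sketch of that dispatch (not the actual Kernel method; note that bool is tested before int here, since bool is a subclass of int):

    # sketch: scalar arguments are specialized on their value, not their class
    def type_name(obj):
        if isinstance(obj, bool):
            return 'B'
        if isinstance(obj, int):
            return 'I' if abs(obj) <= 0xffffffff else 'L'
        if isinstance(obj, float):
            return 'f'
        raise TypeError(obj)

    assert type_name(7) == 'I'
    assert type_name(2**40) == 'L'   # 64-bit literals now get their own cache entry
    assert type_name(0.5) == 'f'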
@@ -526,7 +537,7 @@ class Kernel:
    def __init__(self, fn):
        self.fn = fn

    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
        # create IR module
        context = _triton.ir.context()
        # get just-in-time proto-type of kernel
@@ -549,13 +560,13 @@ class Kernel:
            backend = _triton.runtime.backend.CUDA
        else:
            backend = _triton.runtime.backend.ROCM
        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
        name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
        if shared_mem > max_shared_memory:
            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
        return Binary(backend, name, asm, shared_mem, num_warps)

    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
        # device inference
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        if len(tensor_idxs) == 0:
@@ -632,7 +643,7 @@ class Kernel:
        if binary is None:
            binary = self._compile(
                *wargs, device=device_idx, attributes=attributes,
                num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                num_warps=num_warps, num_stages=num_stages,
                constants=constants, **meta
            )
        if bin_cache_path:
@@ -646,7 +657,7 @@ class Kernel:

        drv_cache[key] = LoadedBinary(device_idx, binary)
        # pack arguments
        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
        fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs)])
        params = struct.pack(fmt, *args)
        # enqueue cached function into stream
        callable = drv_cache[key]
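The format string assembled here feeds `struct.pack`, where 'P' is the native pointer-sized format character. A minimal sketch of the packing step, with made-up values:

    import struct

    # pack a pointer-sized address, an int32 ('I') and a float ('f'),
    # using native byte order and sizes as the launcher does
    fmt = 'PIf'
    params = struct.pack(fmt, 0x7f0000000000, 128, 0.5)
    assert len(params) == struct.calcsize(fmt)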
@@ -715,24 +726,24 @@ class Autotuner:

@functools.lru_cache()
def version_key():
    import pkgutil
    contents = []
    # frontend
    with open(triton.code_gen.__file__, "rb") as f:
        frontend_contents = hashlib.md5(f.read()).hexdigest()
        contents += [hashlib.md5(f.read()).hexdigest()]
    # backend
    with open(triton._C.libtriton.__file__, "rb") as f:
        backend_contents = hashlib.md5(f.read()).hexdigest()

    try:
        nvcc_version = hashlib.md5(subprocess.check_output(["nvcc", "--version"])).hexdigest()
    except Exception:
        nvcc_version = None
        contents += [hashlib.md5(f.read()).hexdigest()]
    # language
    for lib in pkgutil.iter_modules(triton.language.__path__):
        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
            contents += [hashlib.md5(f.read()).hexdigest()]
    # ptxas version
    try:
        ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest()
    except Exception:
        ptxas_version = None

    return (
        triton.__version__, frontend_contents, backend_contents,
        nvcc_version, ptxas_version
    )
    return (triton.__version__, ptxas_version) + tuple(contents)


class JITFunction:

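The rewritten `version_key` folds an md5 digest of every frontend, backend and `triton.language` source file (plus the ptxas version) into a single tuple, so editing any of them invalidates the on-disk binary cache. A minimal sketch of the same idea, over a hypothetical list of file paths:

    import hashlib

    def source_fingerprint(paths):
        # any byte-level edit to any file changes the resulting tuple,
        # which is more reliable than a hand-bumped version number
        digests = []
        for path in paths:
            with open(path, 'rb') as f:
                digests.append(hashlib.md5(f.read()).hexdigest())
        return tuple(digests)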
@@ -9,7 +9,9 @@ def _to_ir(x, builder):
    if isinstance(x, bool):
        return builder.get_int1(x)
    elif isinstance(x, int):
        return builder.get_int32(x)
        if x.__abs__() <= 2**31:
            return builder.get_int32(x)
        return builder.get_int64(x)
    elif isinstance(x, float):
        return builder.get_float32(x)
    if isinstance(x, block):
@@ -307,7 +309,7 @@ def zeros(shape, dtype, _builder=None):

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`triton.float16`
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
@@ -344,6 +346,18 @@ def broadcast_to(input, shape, _builder=None):
    """
    return frontend.broadcast_to(input, shape, _builder)

@builtin
def cat(input, other, _builder=None):
    """
    Concatenate the given blocks.

    :param input: The first input block.
    :type input: Block
    :param other: The second input block.
    :type other: Block
    """
    return frontend.cat(input, other, _builder)


@builtin
def reshape(input, shape, _builder=None):
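A usage sketch for the new builtin, assuming `tl.cat` concatenates two 1D blocks along their leading axis (the kernel below is illustrative and not part of the commit):

    import triton
    import triton.language as tl

    @triton.jit
    def _cat_kernel(Z, **meta):
        x = tl.arange(0, 64)
        y = tl.arange(0, 64) + 64
        z = tl.cat(x, y)   # assumed result: a block of shape (128,)
        tl.store(Z + tl.arange(0, 128), z)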
@@ -385,7 +399,7 @@ def dot(input, other, _builder=None):


@builtin
def load(pointer, mask=None, other=None, _builder=None):
def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
    """
    Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.

@@ -399,8 +413,10 @@ def load(pointer, mask=None, other=None, _builder=None):
    :type mask: Block of triton.int1, optional
    :param other: if mask[idx] is false, return other[idx]
    :type other: Block, optional
    :param cache_modifier: changes the cache operation used in the generated NVIDIA PTX (e.g. '.ca' or '.cg')
    :type cache_modifier: str, optional
    """
    return frontend.load(pointer, mask, other, _builder)
    return frontend.load(pointer, mask, other, cache_modifier, _builder)


@builtin
@@ -520,6 +536,10 @@ def where(condition, x, y, _builder=None):
# Math
# -----------------------

@builtin
def umulhi(x, y, _builder=None):
    return frontend.umulhi(x, y, _builder)


def _add_math_1arg_docstr(name):

    def _decorator(func):
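The new `umulhi` builtin returns the upper 32 bits of the full 64-bit product of two 32-bit operands, which the Philox changes further down rely on. A reference model in plain Python, assuming unsigned 32-bit inputs:

    def umulhi_ref(a, b):
        # upper half of the 64-bit product of two uint32 values
        return ((a & 0xffffffff) * (b & 0xffffffff)) >> 32

    assert umulhi_ref(0xffffffff, 0xffffffff) == 0xfffffffe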
@@ -539,7 +559,6 @@ def _add_math_1arg_docstr(name):
def exp(x, _builder=None):
    return frontend.exp(x, _builder)


@builtin
@_add_math_1arg_docstr("natural logarithm")
def log(x, _builder=None):
@@ -636,6 +655,10 @@ def max_contiguous(input, value, _builder=None):
# Standard library
# -----------------------

@triton.jit
def abs(x):
    return where(x >= 0, x, -x)


@triton.jit
def cdiv(x, div):
    """
@@ -730,4 +753,4 @@ def swizzle2d(i, j, size_i, size_j, size_g):
    # new row and column indices
    new_i = off_i + (ij % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j
    return new_i, new_j

@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
    # 0xCD9E8D57
    return -845247145


@triton.jit
def hacky_to_uint64(x):
    return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)


@triton.jit
def multiply_low_high(a, b):
    return (
        a * b,
        ((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
    )


@triton.jit
def single_round(c0, c1, c2, c3, k0, k1):
    A = PHILOX_ROUND_A()
    B = PHILOX_ROUND_B()
    lo0, hi0 = multiply_low_high(A, c0)
    lo1, hi1 = multiply_low_high(B, c2)

    return (
        hi1 ^ c1 ^ k0,
        lo1,
        hi0 ^ c3 ^ k1,
        lo0,
    )
    _c0, _c2 = c0, c2
    c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
    c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
    c1 = B * _c2
    c3 = A * _c0
    return c0, c1, c2, c3


@triton.jit
def raise_key(k0, k1):
    return (
        k0 + PHILOX_KEY_A(),
        k1 + PHILOX_KEY_B(),
    )

    return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())

@triton.jit
def philox_f(c0, c1, c2, c3, k0, k1):
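The rewritten `single_round` is the standard Philox-4x32 round: two counter words are multiplied by the round constants, and the high halves of the products are XOR-mixed with the remaining counter words and the keys while the low halves are kept. A reference model in plain Python, with the constants written in unsigned form and 32/64-bit arithmetic emulated by masking:

    A, B = 0xD2511F53, 0xCD9E8D57   # Philox-4x32 round constants

    def philox_round_ref(c0, c1, c2, c3, k0, k1):
        p0 = (A * c0) & 0xffffffffffffffff   # full 64-bit products
        p1 = (B * c2) & 0xffffffffffffffff
        return ((p1 >> 32) ^ c1 ^ k0, p1 & 0xffffffff,
                (p0 >> 32) ^ c3 ^ k1, p0 & 0xffffffff)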
@@ -100,11 +84,9 @@ def uint32_to_uniform_float(x):
    This was originally designed for uint32, but it works with int32 too, as long as the int32 values
    uniformly cover their whole range.
    """
    mantissa = x & 0x7fffff
    exp = 127
    res = mantissa | (exp << 23)
    return res.to(tl.float32, bitcast=True) - 1.0

    max = 4.656613e-10  # = 1/MAX_INT = 1/2147483647.
    x = tl.where(x < 0, -x - 1, x)
    return x * max

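The new implementation maps a signed int32 into [0, 1) by folding negatives onto the non-negative range (`-x - 1` is the bitwise complement) and scaling by roughly 1/2147483647. A quick reference check in plain Python:

    def int32_to_uniform_ref(x):
        # fold int32 into [0, 2**31 - 1], then scale into [0, 1)
        x = -x - 1 if x < 0 else x
        return x * 4.656613e-10   # ~= 1 / 2147483647

    assert 0.0 <= int32_to_uniform_ref(-2**31) < 1.0
    assert 0.0 <= int32_to_uniform_ref(2**31 - 1) < 1.0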
@triton.jit
def pair_uniform_to_normal(u1, u2):
@@ -127,8 +109,11 @@ def randint4x(seed, offset):
    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    z = 0
    return philox_f(offset, z, z, z, seed, z)
    z = offset*0  # FIXME: just 0 doesn't work. Likely some error with broadcasting
    seed = hacky_to_uint64(seed)  # uint will solve this
    seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
    seed_lo = (seed & 0xffffffff).to(tl.int32)
    return philox_f(offset, z, z, z, seed_lo, seed_hi)


@triton.jit
251
python/tutorials/05-layer-norm.py
Normal file
@@ -0,0 +1,251 @@
"""
Layer Normalization
====================
"""

import torch
import triton.language as tl
import triton

# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps, **META):
    BLOCK_SIZE = META['BLOCK_SIZE']
    # position of elements processed by this program
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # offset data pointers to start at the row of interest
    X += row * stride
    Y += row * stride
    # load data and cast to float32
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    # compute mean
    mean = tl.sum(x, axis=0) / N
    # compute std
    xmean = tl.where(mask, x - mean, 0.)
    var = tl.sum(xmean * xmean, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    xhat = xmean * rstd
    # write-back mean/rstd
    tl.store(M + row, mean)
    tl.store(V + row, rstd)
    # multiply by weight and add bias
    w = tl.load(W + cols, mask=mask)
    b = tl.load(B + cols, mask=mask)
    y = xhat * w + b
    # write-back
    tl.store(Y + cols, y, mask=mask)


# Backward pass (DX + partial DW + partial DB)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, DY, DW, DB, X, W, B, M, V, Lock,
                             stride, N, eps,
                             **META):
    GROUP_SIZE_M = META['GROUP_SIZE_M']
    BLOCK_SIZE_N = META['BLOCK_SIZE_N']
    # position of elements processed by this program
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    # offset data pointers to start at the row of interest
    X += row * stride
    DY += row * stride
    DX += row * stride
    # offset locks and weight/bias gradient pointers
    # each kernel instance accumulates partial sums for
    # DW and DB into one of GROUP_SIZE_M independent buffers
    # these buffers stay in the L2 cache, which is what allows this kernel to be fast
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols
    DB = DB + lock_id * N + cols
    # load data to SRAM
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(M + row)
    rstd = tl.load(V + row)
    # compute dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * mean1 + mean2)) * rstd
    # write-back dx
    tl.store(DX + cols, dx, mask=mask)
    # accumulate partial sums for dw/db
    partial_dw = (dy * xhat).to(w.dtype)
    partial_db = (dy).to(w.dtype)
    # spin until this buffer's lock is acquired
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # first store doesn't accumulate
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
        partial_db += tl.load(DB, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    tl.store(DB, partial_db, mask=mask)
    # release lock
    tl.atomic_xchg(Lock, 0)

# Backward pass (total DW + total DB)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N, **meta):
    pid = tl.program_id(0)
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, meta['BLOCK_SIZE_M'])
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
        db += tl.load(DB + offs, mask=mask, other=0.)
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)

class LayerNorm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        mean = torch.empty((M, ), dtype=torch.float32, device='cuda')
        rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
                                    x_arg.stride(0), N, eps,
                                    BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, b, m, v = ctx.saved_tensors
        # heuristics for amount of parallel reduction stream for DG/DB
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256
        # allocate output
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
        _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
        _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device)
        dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
        db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)
        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks,
                                       x_arg.stride(0), N, ctx.eps,
                                       BLOCK_SIZE_N=ctx.BLOCK_SIZE,
                                       GROUP_SIZE_M=GROUP_SIZE_M,
                                       num_warps=ctx.num_warps)
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
        # accumulate partial sums in separate kernel
        _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
                                   BLOCK_SIZE_M=32,
                                   BLOCK_SIZE_N=128)
        return dx, None, dw, db, None


layer_norm = LayerNorm.apply


def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
    # compare
    triton.testing.assert_almost_equal(y_tri, y_ref)
    triton.testing.assert_almost_equal(dx_tri, dx_ref)
    triton.testing.assert_almost_equal(db_tri, db_ref, decimal=1)
    triton.testing.assert_almost_equal(dw_tri, dw_ref, decimal=1)

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'apex'],
        line_names=['Triton', 'Torch', 'Apex'],
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='layer-norm-backward',
        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}
    )
)
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # utility functions
    if provider == 'triton':
        y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
    if provider == 'torch':
        y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
    if provider == 'apex':
        import apex
        apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
        y_fwd = lambda: apex_layer_norm(x)
    # forward pass
    if mode == 'forward':
        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, rep=500)
    # backward pass
    if mode == 'backward':
        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
                                                     grad_to_none=[x], rep=500)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


bench_layer_norm.run(save_path='.', print_data=True)
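For completeness, a minimal smoke test of the tutorial's autograd wrapper, assuming a CUDA device and the definitions above (illustrative, not part of the commit):

    x = torch.randn(32, 1024, device='cuda', requires_grad=True)
    w = torch.ones(1024, device='cuda', requires_grad=True)
    b = torch.zeros(1024, device='cuda', requires_grad=True)
    y = layer_norm(x, (1024,), w, b, 1e-5)   # forward: fused Triton kernel
    y.sum().backward()                       # backward: dx, then dw/db reduction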