Compare commits
7 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
2d6df9b518 | ||
|
1b842f8e5e | ||
|
d3e584d4ba | ||
|
d35014ba47 | ||
|
5ce1b726dc | ||
|
858dec8372 | ||
|
90ded16c32 |
1
.github/workflows/integration-tests.yml
vendored
1
.github/workflows/integration-tests.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
- v2.0
|
||||
|
||||
|
||||
jobs:
|
||||
|
@@ -33,7 +33,7 @@ namespace codegen{
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
||||
codegen::target* target,
|
||||
int sm, int num_warps,
|
||||
int num_stages, bool force_nc_cache, int &shared_static);
|
||||
int num_stages, int &shared_static);
|
||||
|
||||
|
||||
}
|
||||
|
@@ -122,8 +122,7 @@ public:
|
||||
analysis::allocation *alloc,
|
||||
analysis::swizzle *swizzle,
|
||||
target *tgt,
|
||||
unsigned num_warps,
|
||||
bool force_nc_cache = false);
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
@@ -148,12 +147,14 @@ public:
|
||||
void visit_store_inst(ir::store_inst*);
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
void visit_cat_inst(ir::cat_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_umulhi_inst(ir::umulhi_inst* x);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
@@ -213,7 +214,6 @@ private:
|
||||
std::set<ir::value*> seen_;
|
||||
|
||||
unsigned num_warps_;
|
||||
bool force_nc_cache_;
|
||||
|
||||
std::map<analysis::data_layout*, Value*> offset_a_m_;
|
||||
std::map<analysis::data_layout*, Value*> offset_a_k_;
|
||||
|
@@ -130,13 +130,14 @@ public:
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg);
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
@@ -153,8 +154,10 @@ public:
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
value *create_select(value *pred, value *if_value, value *else_value);
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
|
@@ -61,13 +61,14 @@ struct dispatch{
|
||||
|
||||
// casting ops
|
||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
@@ -90,6 +91,7 @@ struct dispatch{
|
||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
|
||||
// math
|
||||
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
|
@@ -132,6 +132,7 @@ enum value_id_t: unsigned {
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
INST_CAT,
|
||||
INST_BROADCAST,
|
||||
INST_DOWNCAST,
|
||||
// builtin
|
||||
@@ -142,6 +143,7 @@ enum value_id_t: unsigned {
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_UMULHI,
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
|
@@ -394,22 +394,38 @@ public:
|
||||
|
||||
// load
|
||||
class load_inst: public io_inst {
|
||||
public:
|
||||
enum CACHE_MODIFIER : uint32_t {
|
||||
NONE=0,
|
||||
CA,
|
||||
CG,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
if (cache_ == CA) return ".ca";
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load"; }
|
||||
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
@@ -419,8 +435,8 @@ public:
|
||||
// masked load
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load"; }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
@@ -429,6 +445,7 @@ public:
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
@@ -438,8 +455,8 @@ public:
|
||||
// masked load async
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async_async"; }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
@@ -448,6 +465,7 @@ public:
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
@@ -502,6 +520,21 @@ public:
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
class cat_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "cat"; }
|
||||
cat_inst(value *x, value *y, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cat_inst)
|
||||
};
|
||||
|
||||
// retile
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
@@ -636,6 +669,17 @@ public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class umulhi_inst: public builtin_inst {
|
||||
private:
|
||||
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "umulhi"; }
|
||||
_TRITON_DEFINE_CLONE(umulhi_inst)
|
||||
_TRITON_DEFINE_ACCEPT(umulhi_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -785,6 +829,7 @@ public:
|
||||
// intrinsics classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
class copy_to_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
@@ -866,35 +911,6 @@ public:
|
||||
instruction *next=nullptr);
|
||||
};
|
||||
|
||||
//// On NVIDIA, implementation is such that
|
||||
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
//// so as to enable re-association on nv_static_program_idx which is constant
|
||||
//class make_range_dyn: public instruction {
|
||||
//private:
|
||||
// make_range_dyn(type *ty, const std::string &name, instruction *next);
|
||||
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
// _TRITON_DEFINE_CLONE(make_range_dyn)
|
||||
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
|
||||
|
||||
//public:
|
||||
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
//class make_range_sta: public constant {
|
||||
//private:
|
||||
// make_range_sta(make_range *range);
|
||||
|
||||
//public:
|
||||
// static make_range_sta *get(make_range* range);
|
||||
// make_range* get_range() const;
|
||||
// std::string repr() const { return "nv_static_program_idx"; }
|
||||
// _TRITON_DEFINE_ACCEPT(make_range_sta)
|
||||
|
||||
//private:
|
||||
// make_range *range_;
|
||||
//};
|
||||
|
||||
|
||||
/* constant range */
|
||||
class make_range: public instruction{
|
||||
make_range(type *ty, constant_int* first, constant_int* last);
|
||||
|
@@ -45,9 +45,11 @@ class masked_store_inst;
|
||||
class retile_inst;
|
||||
class reshape_inst;
|
||||
class splat_inst;
|
||||
class cat_inst;
|
||||
class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class umulhi_inst;
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
@@ -122,6 +124,7 @@ public:
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
@@ -129,6 +132,7 @@ public:
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||
|
||||
@@ -150,13 +154,10 @@ public:
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
// virtual void visit_make_range_sta(make_range_sta*) = 0;
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
|
@@ -116,7 +116,8 @@ void axes::update_graph(ir::instruction *i) {
|
||||
switch (i->get_id()) {
|
||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);;
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
|
@@ -499,6 +499,7 @@ void layouts::run(ir::module &mod) {
|
||||
make_graph(i);
|
||||
});
|
||||
|
||||
|
||||
// connected components
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
|
||||
|
@@ -25,7 +25,7 @@ namespace codegen {
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
|
||||
int cc, int num_warps, int num_stages, bool force_nc_cache, int& shared_static) {
|
||||
int cc, int num_warps, int num_stages, int& shared_static) {
|
||||
// generate llvm code
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
|
||||
@@ -46,7 +46,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::prefetch prefetch_s(target);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps, force_nc_cache);
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
|
||||
// run passes
|
||||
dce.run(ir);
|
||||
peephole.run(ir);
|
||||
|
@@ -197,9 +197,9 @@ generator::generator(analysis::axes *a_axes,
|
||||
analysis::allocation *alloc,
|
||||
analysis::swizzle *swizzle,
|
||||
target *tgt,
|
||||
unsigned num_warps, bool force_nc_cache)
|
||||
unsigned num_warps)
|
||||
: a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
|
||||
tgt_(tgt), num_warps_(num_warps), force_nc_cache_(force_nc_cache), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||
tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
|
||||
|
||||
}
|
||||
|
||||
@@ -629,10 +629,9 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
// -----
|
||||
std::ostringstream asm_oss;
|
||||
asm_oss << "@$" << n_words; // predicate
|
||||
// if(force_nc_cache_)
|
||||
asm_oss << " ld.global";
|
||||
// else
|
||||
// asm_oss << " ld.global.cg";
|
||||
asm_oss << " ld.global";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca";
|
||||
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
@@ -775,6 +774,22 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `cat`
|
||||
*/
|
||||
void generator::visit_cat_inst(ir::cat_inst* x) {
|
||||
auto idxs = idxs_.at(x);
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
int i = 0;
|
||||
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
|
||||
vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
|
||||
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
|
||||
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reshape`
|
||||
@@ -862,6 +877,20 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `umulhi`
|
||||
*/
|
||||
void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
||||
std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
|
||||
InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value* lhs = vals_[x->get_operand(0)][idx];
|
||||
Value* rhs = vals_[x->get_operand(1)][idx];
|
||||
vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `sin`
|
||||
*/
|
||||
|
@@ -11,6 +11,8 @@ namespace transform{
|
||||
|
||||
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
||||
std::set<ir::value*>& seen) {
|
||||
if (dynamic_cast<ir::phi_node*>(root))
|
||||
return root;
|
||||
if(!seen.insert(root).second)
|
||||
return root;
|
||||
if(!root->get_type()->is_block_ty())
|
||||
|
@@ -116,7 +116,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
|
||||
int nts = layout->nts(layout->get_order()[0]);
|
||||
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
if(nts*dtsize >= 4){
|
||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
|
||||
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier());
|
||||
copy_to_shared->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
}
|
||||
@@ -206,7 +206,8 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
|
||||
builder.set_insert_point(select);
|
||||
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
|
||||
if_value->get_mask_operand(),
|
||||
select->get_else_value_op());
|
||||
select->get_else_value_op(),
|
||||
if_value->get_cache_modifier());
|
||||
select->replace_all_uses_with(new_load);
|
||||
return true;
|
||||
}
|
||||
|
@@ -111,6 +111,8 @@ struct pipeline_info_t {
|
||||
};
|
||||
|
||||
void pipeline::run(ir::module &mod) {
|
||||
if (num_stages_ <= 1)
|
||||
return;
|
||||
// *Very* conservative heuristics for pre-fetching.
|
||||
// A load instruction can be pipelined if:
|
||||
// - the pointer is a phi node that references a value
|
||||
@@ -176,7 +178,7 @@ void pipeline::run(ir::module &mod) {
|
||||
false_value = remat_false_value;
|
||||
} else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
|
||||
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier());
|
||||
|
||||
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||
// mask is the loop condition of the previous iteration
|
||||
@@ -191,7 +193,7 @@ void pipeline::run(ir::module &mod) {
|
||||
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
|
||||
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier());
|
||||
}
|
||||
|
||||
// create new phis for induction variables
|
||||
@@ -220,7 +222,7 @@ void pipeline::run(ir::module &mod) {
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
|
||||
|
||||
|
||||
// phi node
|
||||
@@ -255,7 +257,7 @@ void pipeline::run(ir::module &mod) {
|
||||
}
|
||||
else
|
||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier());
|
||||
// pre-fetch next iteration
|
||||
builder.set_insert_point(block->get_inst_list().back());
|
||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||
@@ -266,7 +268,7 @@ void pipeline::run(ir::module &mod) {
|
||||
next_mask = builder.create_and(next_mask, remat_mask);
|
||||
false_value = remat_false_value;
|
||||
}
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier());
|
||||
// phi node
|
||||
builder.set_insert_point(block->get_first_non_phi());
|
||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||
|
@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
|
||||
err = system(cmd.c_str());
|
||||
CUmodule ret;
|
||||
std::ifstream _cubin(_fbin, std::ios::binary );
|
||||
|
@@ -273,16 +273,16 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
|
||||
// load/store instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr){
|
||||
return insert(unmasked_load_inst::create(ptr));
|
||||
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){
|
||||
return insert(unmasked_load_inst::create(ptr, cache));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val){
|
||||
return insert(unmasked_store_inst::create(ptr, val));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value));
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){
|
||||
return insert(masked_load_inst::create(ptr, mask, false_value, cache));
|
||||
}
|
||||
|
||||
value *builder::create_masked_store(value *ptr, value *val, value *mask){
|
||||
@@ -297,6 +297,10 @@ value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(reshape_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
value *builder::create_cat(value *lhs, value *rhs) {
|
||||
return insert(cat_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(splat_inst::create(arg, shapes));
|
||||
}
|
||||
@@ -369,6 +373,9 @@ value *builder::create_select(value *pred, value *if_value, value *else_value){
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_umulhi(value *lhs, value *rhs) {
|
||||
return insert(umulhi_inst::create(lhs, rhs));
|
||||
}
|
||||
|
||||
value *builder::create_copy_to_shared(value *arg) {
|
||||
return insert(copy_to_shared_inst::create(arg));
|
||||
@@ -379,8 +386,8 @@ value *builder::create_copy_from_shared(value *arg) {
|
||||
return insert(copy_from_shared_inst::create(arg));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value));
|
||||
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) {
|
||||
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
|
@@ -368,6 +368,10 @@ ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *b
|
||||
return builder->create_reshape(input, dst_shape);
|
||||
}
|
||||
|
||||
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
|
||||
return builder->create_cat(lhs, rhs);
|
||||
}
|
||||
|
||||
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
|
||||
if (!input->get_type()->is_block_ty())
|
||||
return builder->create_splat(input, shape);
|
||||
@@ -497,7 +501,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
// Memory Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder* builder) {
|
||||
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
|
||||
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
|
||||
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
|
||||
if(ptr->get_type()->is_block_ty()){
|
||||
@@ -517,8 +521,17 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
|
||||
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
|
||||
ptr = dispatch::cast(ptr, ptr_ty, builder);
|
||||
}
|
||||
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
|
||||
if (!cache_modifier.empty()) {
|
||||
if (cache_modifier == ".ca")
|
||||
cache = load_inst::CA;
|
||||
else if (cache_modifier == ".cg")
|
||||
cache = load_inst::CG;
|
||||
else
|
||||
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
|
||||
}
|
||||
if (!mask && !other)
|
||||
return builder->create_load(ptr);
|
||||
return builder->create_load(ptr, cache);
|
||||
if (!mask)
|
||||
throw std::runtime_error("`other` cannot be provided without `mask`");
|
||||
auto shape = ptr->get_type()->get_block_shapes();
|
||||
@@ -527,7 +540,7 @@ ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, ir:
|
||||
if(ptr->get_type()->is_block_ty())
|
||||
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
|
||||
}
|
||||
return builder->create_masked_load(ptr, mask, other);
|
||||
return builder->create_masked_load(ptr, mask, other, cache);
|
||||
}
|
||||
|
||||
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
|
||||
@@ -706,6 +719,11 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
|
||||
binary_op_type_checking(x, y, builder);
|
||||
return builder->insert(umulhi_inst::create(x, y));
|
||||
}
|
||||
|
||||
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
|
||||
return builder->create_exp(x);
|
||||
}
|
||||
|
@@ -433,8 +433,8 @@ io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &n
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -447,41 +447,44 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
}
|
||||
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, name, next);
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, cache, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_inst(ptr, mask, false_value, name, next);
|
||||
return new masked_load_inst(ptr, mask, false_value, cache, name, next);
|
||||
}
|
||||
|
||||
// masked load async
|
||||
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
|
||||
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next) {
|
||||
return new masked_load_async_inst(ptr, mask, false_value, name, next);
|
||||
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next);
|
||||
}
|
||||
|
||||
// store
|
||||
@@ -519,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
|
||||
: instruction(block_type::get(x->get_type()->get_scalar_ty(),
|
||||
{x->get_type()->get_block_shapes()[0] +
|
||||
y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
|
||||
set_operand(0, x);
|
||||
set_operand(1, y);
|
||||
}
|
||||
|
||||
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new cat_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
// retile
|
||||
|
||||
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
||||
|
||||
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
@@ -758,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
}
|
||||
|
||||
|
||||
// umulhi
|
||||
|
||||
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
|
||||
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
|
||||
return new umulhi_inst(lhs, rhs, name, next);
|
||||
}
|
||||
|
||||
|
||||
// exp
|
||||
|
||||
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||
@@ -874,7 +907,7 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
||||
make_range *make_range::create(constant_int *first, constant_int *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
assert(first->get_type() == last->get_type());
|
||||
assert(((constant_int*)first)->get_value() == 0);
|
||||
// assert(((constant_int*)first)->get_value() == 0);
|
||||
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
||||
return new make_range(ty, first, last);
|
||||
}
|
||||
|
@@ -121,7 +121,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
setup(
|
||||
name="triton",
|
||||
version="1.1.1",
|
||||
version="1.1.2",
|
||||
author="Philippe Tillet",
|
||||
author_email="phil@openai.com",
|
||||
description="A language and compiler for custom Deep Learning operations",
|
||||
|
@@ -203,7 +203,7 @@ std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_
|
||||
// CUDA
|
||||
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
|
||||
uint64_t device, int num_warps, int num_stages,
|
||||
bool force_nc_cache, asm_map_t &asm_map){
|
||||
asm_map_t &asm_map){
|
||||
llvm::LLVMContext ctx;
|
||||
// device properties
|
||||
CUdevice dev = (CUdevice)device;
|
||||
@@ -215,7 +215,7 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
// Triton-IR -> NVPTX LLVM-IR
|
||||
triton::codegen::nvidia_cu_target target(cc);
|
||||
int n_shared_bytes;
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, force_nc_cache, n_shared_bytes);
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
|
||||
std::string tmp;
|
||||
llvm::raw_string_ostream llir(tmp);
|
||||
llir << *llvm;
|
||||
@@ -236,12 +236,12 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
|
||||
// HIP
|
||||
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
|
||||
uint64_t device, int num_warps, int num_stages,
|
||||
bool force_nc_cache, asm_map_t &asm_map){
|
||||
asm_map_t &asm_map){
|
||||
llvm::LLVMContext ctx;
|
||||
// Triton-IR -> NVPTX LLVM-IR
|
||||
triton::codegen::amd_cl_target target;
|
||||
int n_shared_bytes;
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, force_nc_cache, n_shared_bytes);
|
||||
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
|
||||
std::string tmp;
|
||||
llvm::raw_string_ostream llir(tmp);
|
||||
llir << *llvm;
|
||||
@@ -255,7 +255,7 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
|
||||
|
||||
void init_triton_codegen(py::module &&m) {
|
||||
m.def(
|
||||
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, bool force_nc_cache) {
|
||||
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
|
||||
std::string name = ir.get_function_list()[0]->get_name();
|
||||
// record asm as we generate
|
||||
asm_map_t asm_map;
|
||||
@@ -264,9 +264,9 @@ void init_triton_codegen(py::module &&m) {
|
||||
asm_map["ttir"] = py::cast(ttir.str());
|
||||
llvm::LLVMContext ctx;
|
||||
if(backend == CUDA)
|
||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
|
||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
||||
if(backend == ROCM)
|
||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, force_nc_cache, asm_map);
|
||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
|
||||
}, py::return_value_policy::take_ownership);
|
||||
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
||||
if(backend == CUDA)
|
||||
@@ -313,6 +313,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||
// type manipuatation
|
||||
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
@@ -340,6 +341,7 @@ void init_triton_frontend(py::module &&m) {
|
||||
m.def("max", &ir::dispatch::max, ret::reference);
|
||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||
// math
|
||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
|
@@ -599,6 +599,30 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
|
||||
reference_out =torch.matmul(in1, in2)
|
||||
triton.testing.allclose(out, reference_out)
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
|
||||
def test_load_cache_modifier(cache):
|
||||
src = torch.empty(128, device='cuda')
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst, src, **meta):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src+offsets, cache_modifier=meta['CACHE'])
|
||||
tl.store(dst+offsets, x)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
|
||||
if cache == '':
|
||||
assert 'ld.global.ca' not in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'ld.global.cg' in ptx
|
||||
assert 'ld.global.ca' not in ptx
|
||||
if cache == '.ca':
|
||||
assert 'ld.global.ca' in ptx
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
# ---------------
|
||||
# test store
|
||||
# ---------------
|
||||
|
@@ -537,7 +537,7 @@ class Kernel:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, force_nc_cache, **meta):
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
@@ -560,13 +560,13 @@ class Kernel:
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
else:
|
||||
backend = _triton.runtime.backend.ROCM
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
|
||||
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
|
||||
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
|
||||
if shared_mem > max_shared_memory:
|
||||
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
|
||||
return Binary(backend, name, asm, shared_mem, num_warps)
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
|
||||
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
|
||||
# device inference
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
if len(tensor_idxs) == 0:
|
||||
@@ -643,7 +643,7 @@ class Kernel:
|
||||
if binary is None:
|
||||
binary = self._compile(
|
||||
*wargs, device=device_idx, attributes=attributes,
|
||||
num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
|
||||
num_warps=num_warps, num_stages=num_stages,
|
||||
constants=constants, **meta
|
||||
)
|
||||
if bin_cache_path:
|
||||
|
@@ -346,6 +346,18 @@ def broadcast_to(input, shape, _builder=None):
|
||||
"""
|
||||
return frontend.broadcast_to(input, shape, _builder)
|
||||
|
||||
@builtin
|
||||
def cat(input, other, _builder=None):
|
||||
"""
|
||||
Concatenate the given blocks
|
||||
|
||||
:param input: The first input block.
|
||||
:type input:
|
||||
:param other: The second input block.
|
||||
:type other:
|
||||
"""
|
||||
return frontend.cat(input, other, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def reshape(input, shape, _builder=None):
|
||||
@@ -387,7 +399,7 @@ def dot(input, other, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def load(pointer, mask=None, other=None, _builder=None):
|
||||
def load(pointer, mask=None, other=None, cache_modifier="", _builder=None):
|
||||
"""
|
||||
Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
|
||||
|
||||
@@ -401,8 +413,10 @@ def load(pointer, mask=None, other=None, _builder=None):
|
||||
:type mask: Block of triton.int1, optional
|
||||
:param other: if mask[idx] is false, return other[idx]
|
||||
:type other: Block, optional
|
||||
:param cache_modifier: changes cache option in nvidia ptx
|
||||
'type cache_modifier: str, optional
|
||||
"""
|
||||
return frontend.load(pointer, mask, other, _builder)
|
||||
return frontend.load(pointer, mask, other, cache_modifier, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@@ -522,6 +536,10 @@ def where(condition, x, y, _builder=None):
|
||||
# Math
|
||||
# -----------------------
|
||||
|
||||
@builtin
|
||||
def umulhi(x, y, _builder=None):
|
||||
return frontend.umulhi(x, y, _builder)
|
||||
|
||||
def _add_math_1arg_docstr(name):
|
||||
|
||||
def _decorator(func):
|
||||
@@ -541,7 +559,6 @@ def _add_math_1arg_docstr(name):
|
||||
def exp(x, _builder=None):
|
||||
return frontend.exp(x, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
@_add_math_1arg_docstr("natural logarithm")
|
||||
def log(x, _builder=None):
|
||||
|
@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
|
||||
# 0xCD9E8D57
|
||||
return -845247145
|
||||
|
||||
|
||||
@triton.jit
|
||||
def hacky_to_uint64(x):
|
||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def multiply_low_high(a, b):
|
||||
return (
|
||||
a * b,
|
||||
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def single_round(c0, c1, c2, c3, k0, k1):
|
||||
A = PHILOX_ROUND_A()
|
||||
B = PHILOX_ROUND_B()
|
||||
lo0, hi0 = multiply_low_high(A, c0)
|
||||
lo1, hi1 = multiply_low_high(B, c2)
|
||||
|
||||
return (
|
||||
hi1 ^ c1 ^ k0,
|
||||
lo1,
|
||||
hi0 ^ c3 ^ k1,
|
||||
lo0,
|
||||
)
|
||||
_c0, _c2 = c0, c2
|
||||
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||
c1 = B * _c2
|
||||
c3 = A * _c0
|
||||
return c0, c1, c2, c3
|
||||
|
||||
|
||||
@triton.jit
|
||||
def raise_key(k0, k1):
|
||||
return (
|
||||
k0 + PHILOX_KEY_A(),
|
||||
k1 + PHILOX_KEY_B(),
|
||||
)
|
||||
|
||||
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
||||
|
||||
@triton.jit
|
||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
||||
@@ -125,7 +109,7 @@ def randint4x(seed, offset):
|
||||
:param seed: The seed for generating random numbers.
|
||||
:param offsets: The offsets to generate random numbers for.
|
||||
"""
|
||||
z = 0
|
||||
z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting
|
||||
seed = hacky_to_uint64(seed) # uint will solve this
|
||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||
|
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
Layer Normalization
|
||||
====================
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
import triton
|
||||
|
Reference in New Issue
Block a user