[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)
This commit is contained in:
@@ -147,12 +147,14 @@ public:
|
|||||||
void visit_store_inst(ir::store_inst*);
|
void visit_store_inst(ir::store_inst*);
|
||||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||||
|
void visit_cat_inst(ir::cat_inst*);
|
||||||
void visit_reshape_inst(ir::reshape_inst*);
|
void visit_reshape_inst(ir::reshape_inst*);
|
||||||
void visit_splat_inst(ir::splat_inst*);
|
void visit_splat_inst(ir::splat_inst*);
|
||||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||||
void visit_downcast_inst(ir::downcast_inst*);
|
void visit_downcast_inst(ir::downcast_inst*);
|
||||||
void visit_exp_inst(ir::exp_inst*);
|
void visit_exp_inst(ir::exp_inst*);
|
||||||
void visit_cos_inst(ir::cos_inst*);
|
void visit_cos_inst(ir::cos_inst*);
|
||||||
|
void visit_umulhi_inst(ir::umulhi_inst* x);
|
||||||
void visit_sin_inst(ir::sin_inst*);
|
void visit_sin_inst(ir::sin_inst*);
|
||||||
void visit_log_inst(ir::log_inst*);
|
void visit_log_inst(ir::log_inst*);
|
||||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||||
|
@@ -137,6 +137,7 @@ public:
|
|||||||
// Block instruction
|
// Block instruction
|
||||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||||
|
value *create_cat(value *lhs, value *rhs);
|
||||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||||
// Built-in instruction
|
// Built-in instruction
|
||||||
value *create_get_program_id(unsigned axis);
|
value *create_get_program_id(unsigned axis);
|
||||||
@@ -153,6 +154,8 @@ public:
|
|||||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||||
value *create_select(value *pred, value *if_value, value *else_value);
|
value *create_select(value *pred, value *if_value, value *else_value);
|
||||||
// Intrinsics
|
// Intrinsics
|
||||||
|
// These have no place in the IR, and hopefully they can be removed at some point
|
||||||
|
value *create_umulhi(value* lhs, value* rhs);
|
||||||
value *create_copy_to_shared(value *arg);
|
value *create_copy_to_shared(value *arg);
|
||||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||||
value *create_copy_from_shared(value *arg);
|
value *create_copy_from_shared(value *arg);
|
||||||
|
@@ -61,6 +61,7 @@ struct dispatch{
|
|||||||
|
|
||||||
// casting ops
|
// casting ops
|
||||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||||
|
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||||
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||||
@@ -90,6 +91,7 @@ struct dispatch{
|
|||||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||||
|
|
||||||
// math
|
// math
|
||||||
|
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||||
|
@@ -132,6 +132,7 @@ enum value_id_t: unsigned {
|
|||||||
// retile
|
// retile
|
||||||
INST_RESHAPE,
|
INST_RESHAPE,
|
||||||
INST_SPLAT,
|
INST_SPLAT,
|
||||||
|
INST_CAT,
|
||||||
INST_BROADCAST,
|
INST_BROADCAST,
|
||||||
INST_DOWNCAST,
|
INST_DOWNCAST,
|
||||||
// builtin
|
// builtin
|
||||||
@@ -142,6 +143,7 @@ enum value_id_t: unsigned {
|
|||||||
INST_ATOMIC_EXCH,
|
INST_ATOMIC_EXCH,
|
||||||
INST_ATOMIC_RMW,
|
INST_ATOMIC_RMW,
|
||||||
// math
|
// math
|
||||||
|
INST_UMULHI,
|
||||||
INST_EXP,
|
INST_EXP,
|
||||||
INST_COS,
|
INST_COS,
|
||||||
INST_SIN,
|
INST_SIN,
|
||||||
|
@@ -520,6 +520,21 @@ public:
|
|||||||
// retile_inst classes
|
// retile_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// cat
|
||||||
|
|
||||||
|
class cat_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const { return "cat"; }
|
||||||
|
cat_inst(value *x, value *y, const std::string &name, instruction *next);
|
||||||
|
|
||||||
|
public:
|
||||||
|
static instruction* create(value *lhs, value *rhs,
|
||||||
|
const std::string &name = "",
|
||||||
|
instruction *next = nullptr);
|
||||||
|
_TRITON_DEFINE_CLONE(cat_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(cat_inst)
|
||||||
|
};
|
||||||
|
|
||||||
// retile
|
// retile
|
||||||
|
|
||||||
class retile_inst: public unary_inst {
|
class retile_inst: public unary_inst {
|
||||||
@@ -654,6 +669,17 @@ public:
|
|||||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class umulhi_inst: public builtin_inst {
|
||||||
|
private:
|
||||||
|
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
std::string repr_impl() const { return "umulhi"; }
|
||||||
|
_TRITON_DEFINE_CLONE(umulhi_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(umulhi_inst)
|
||||||
|
|
||||||
|
public:
|
||||||
|
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
class exp_inst: public builtin_inst {
|
class exp_inst: public builtin_inst {
|
||||||
private:
|
private:
|
||||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
@@ -803,6 +829,7 @@ public:
|
|||||||
// intrinsics classes
|
// intrinsics classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
class copy_to_shared_inst: public unary_inst{
|
class copy_to_shared_inst: public unary_inst{
|
||||||
private:
|
private:
|
||||||
using unary_inst::unary_inst;
|
using unary_inst::unary_inst;
|
||||||
@@ -884,35 +911,6 @@ public:
|
|||||||
instruction *next=nullptr);
|
instruction *next=nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
//// On NVIDIA, implementation is such that
|
|
||||||
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
|
||||||
//// so as to enable re-association on nv_static_program_idx which is constant
|
|
||||||
//class make_range_dyn: public instruction {
|
|
||||||
//private:
|
|
||||||
// make_range_dyn(type *ty, const std::string &name, instruction *next);
|
|
||||||
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
|
||||||
// _TRITON_DEFINE_CLONE(make_range_dyn)
|
|
||||||
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
|
|
||||||
|
|
||||||
//public:
|
|
||||||
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
|
||||||
//};
|
|
||||||
|
|
||||||
//class make_range_sta: public constant {
|
|
||||||
//private:
|
|
||||||
// make_range_sta(make_range *range);
|
|
||||||
|
|
||||||
//public:
|
|
||||||
// static make_range_sta *get(make_range* range);
|
|
||||||
// make_range* get_range() const;
|
|
||||||
// std::string repr() const { return "nv_static_program_idx"; }
|
|
||||||
// _TRITON_DEFINE_ACCEPT(make_range_sta)
|
|
||||||
|
|
||||||
//private:
|
|
||||||
// make_range *range_;
|
|
||||||
//};
|
|
||||||
|
|
||||||
|
|
||||||
/* constant range */
|
/* constant range */
|
||||||
class make_range: public instruction{
|
class make_range: public instruction{
|
||||||
make_range(type *ty, constant_int* first, constant_int* last);
|
make_range(type *ty, constant_int* first, constant_int* last);
|
||||||
|
@@ -45,9 +45,11 @@ class masked_store_inst;
|
|||||||
class retile_inst;
|
class retile_inst;
|
||||||
class reshape_inst;
|
class reshape_inst;
|
||||||
class splat_inst;
|
class splat_inst;
|
||||||
|
class cat_inst;
|
||||||
class broadcast_inst;
|
class broadcast_inst;
|
||||||
class downcast_inst;
|
class downcast_inst;
|
||||||
|
|
||||||
|
class umulhi_inst;
|
||||||
class exp_inst;
|
class exp_inst;
|
||||||
class cos_inst;
|
class cos_inst;
|
||||||
class sin_inst;
|
class sin_inst;
|
||||||
@@ -122,6 +124,7 @@ public:
|
|||||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||||
|
|
||||||
|
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
|
||||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||||
@@ -129,6 +132,7 @@ public:
|
|||||||
|
|
||||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||||
|
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||||
|
|
||||||
@@ -150,13 +154,10 @@ public:
|
|||||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||||
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
|
||||||
virtual void visit_make_range(make_range*) = 0;
|
virtual void visit_make_range(make_range*) = 0;
|
||||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_function(function*) = 0;
|
virtual void visit_function(function*) = 0;
|
||||||
|
|
||||||
// virtual void visit_make_range_sta(make_range_sta*) = 0;
|
|
||||||
virtual void visit_undef_value(undef_value*) = 0;
|
virtual void visit_undef_value(undef_value*) = 0;
|
||||||
virtual void visit_constant_int(constant_int*) = 0;
|
virtual void visit_constant_int(constant_int*) = 0;
|
||||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||||
|
@@ -116,7 +116,8 @@ void axes::update_graph(ir::instruction *i) {
|
|||||||
switch (i->get_id()) {
|
switch (i->get_id()) {
|
||||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||||
case ir::INST_SPLAT: return update_graph_no_edge(i);;
|
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||||
|
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||||
case ir::INST_TRANS: return update_graph_trans(i);
|
case ir::INST_TRANS: return update_graph_trans(i);
|
||||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||||
case ir::INST_DOT: return update_graph_dot(i);
|
case ir::INST_DOT: return update_graph_dot(i);
|
||||||
|
@@ -499,6 +499,7 @@ void layouts::run(ir::module &mod) {
|
|||||||
make_graph(i);
|
make_graph(i);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
// connected components
|
// connected components
|
||||||
graph_.connected_components(&values_, &groups_);
|
graph_.connected_components(&values_, &groups_);
|
||||||
|
|
||||||
|
@@ -774,6 +774,22 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
|||||||
visit_store_inst(x);
|
visit_store_inst(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Code Generation for `cat`
|
||||||
|
*/
|
||||||
|
void generator::visit_cat_inst(ir::cat_inst* x) {
|
||||||
|
auto idxs = idxs_.at(x);
|
||||||
|
ir::value* lhs = x->get_operand(0);
|
||||||
|
ir::value* rhs = x->get_operand(1);
|
||||||
|
int i = 0;
|
||||||
|
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
|
||||||
|
vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
|
||||||
|
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
|
||||||
|
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `reshape`
|
* \brief Code Generation for `reshape`
|
||||||
@@ -861,6 +877,20 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Code Generation for `umulhi`
|
||||||
|
*/
|
||||||
|
void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
||||||
|
std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
|
||||||
|
FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
|
||||||
|
InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
|
||||||
|
for(auto idx: idxs_.at(x)){
|
||||||
|
Value* lhs = vals_[x->get_operand(0)][idx];
|
||||||
|
Value* rhs = vals_[x->get_operand(1)][idx];
|
||||||
|
vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Code Generation for `sin`
|
* \brief Code Generation for `sin`
|
||||||
*/
|
*/
|
||||||
|
@@ -11,6 +11,8 @@ namespace transform{
|
|||||||
|
|
||||||
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
||||||
std::set<ir::value*>& seen) {
|
std::set<ir::value*>& seen) {
|
||||||
|
if (dynamic_cast<ir::phi_node*>(root))
|
||||||
|
return root;
|
||||||
if(!seen.insert(root).second)
|
if(!seen.insert(root).second)
|
||||||
return root;
|
return root;
|
||||||
if(!root->get_type()->is_block_ty())
|
if(!root->get_type()->is_block_ty())
|
||||||
|
@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
|||||||
ofs.close();
|
ofs.close();
|
||||||
std::string cmd;
|
std::string cmd;
|
||||||
int err;
|
int err;
|
||||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
|
||||||
err = system(cmd.c_str());
|
err = system(cmd.c_str());
|
||||||
CUmodule ret;
|
CUmodule ret;
|
||||||
std::ifstream _cubin(_fbin, std::ios::binary );
|
std::ifstream _cubin(_fbin, std::ios::binary );
|
||||||
|
@@ -297,6 +297,10 @@ value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
|
|||||||
return insert(reshape_inst::create(arg, shapes));
|
return insert(reshape_inst::create(arg, shapes));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a `cat` instruction over `lhs` and `rhs` and hands it to `insert`.
value *builder::create_cat(value *lhs, value *rhs) {
  instruction *cat = cat_inst::create(lhs, rhs);
  return insert(cat);
}
|
||||||
|
|
||||||
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
|
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
|
||||||
return insert(splat_inst::create(arg, shapes));
|
return insert(splat_inst::create(arg, shapes));
|
||||||
}
|
}
|
||||||
@@ -369,6 +373,9 @@ value *builder::create_select(value *pred, value *if_value, value *else_value){
|
|||||||
// intrinsic instructions
|
// intrinsic instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Creates a `umulhi` instruction over `lhs` and `rhs` and hands it to `insert`.
value *builder::create_umulhi(value *lhs, value *rhs) {
  instruction *mul_hi = umulhi_inst::create(lhs, rhs);
  return insert(mul_hi);
}
|
||||||
|
|
||||||
value *builder::create_copy_to_shared(value *arg) {
|
value *builder::create_copy_to_shared(value *arg) {
|
||||||
return insert(copy_to_shared_inst::create(arg));
|
return insert(copy_to_shared_inst::create(arg));
|
||||||
|
@@ -368,6 +368,10 @@ ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *b
|
|||||||
return builder->create_reshape(input, dst_shape);
|
return builder->create_reshape(input, dst_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Front-end entry point for `cat`: forwards both operands to the builder.
// NOTE(review): no type checking happens here, while the cat_inst
// constructor reads get_block_shapes()[0] of both operands -- presumably
// callers must pass block values; confirm whether a check belongs here.
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::value *result = builder->create_cat(lhs, rhs);
  return result;
}
|
||||||
|
|
||||||
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
|
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
|
||||||
if (!input->get_type()->is_block_ty())
|
if (!input->get_type()->is_block_ty())
|
||||||
return builder->create_splat(input, shape);
|
return builder->create_splat(input, shape);
|
||||||
@@ -715,6 +719,11 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build
|
|||||||
// Math
|
// Math
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Front-end entry point for `umulhi`.
// Runs the shared binary-operand type checking on `x` and `y`, then builds
// the instruction through the builder.
// NOTE(review): code generation for `umulhi` uses a 32-bit inline asm
// (i32 operands and result); nothing here restricts the operands to that
// type -- confirm whether an explicit check is wanted.
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
  binary_op_type_checking(x, y, builder);
  // Go through builder::create_umulhi like the other dispatch functions do
  // for their instructions, rather than creating and inserting by hand.
  return builder->create_umulhi(x, y);
}
|
||||||
|
|
||||||
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
|
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
|
||||||
return builder->create_exp(x);
|
return builder->create_exp(x);
|
||||||
}
|
}
|
||||||
|
@@ -522,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
|||||||
// retile_inst classes
|
// retile_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// cat
|
||||||
|
|
||||||
|
// Result type: a one-dimensional block with the scalar type of `x`, whose
// extent is the sum of the first block dimension of `x` and of `y`.
// Operand 0 is `x`, operand 1 is `y`.
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
  : instruction(
      block_type::get(
        x->get_type()->get_scalar_ty(),
        {x->get_type()->get_block_shapes()[0] +
         y->get_type()->get_block_shapes()[0] }),
      INST_CAT, 2, name, next) {
  set_operand(0, x);
  set_operand(1, y);
}
|
||||||
|
|
||||||
|
// Factory: heap-allocates a cat_inst over `lhs` and `rhs` and returns it.
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  cat_inst *created = new cat_inst(lhs, rhs, name, next);
  return created;
}
|
||||||
|
|
||||||
|
// retile
|
||||||
|
|
||||||
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
|
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
|
||||||
const std::string &name, instruction *next)
|
const std::string &name, instruction *next)
|
||||||
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// reshape
|
// reshape
|
||||||
|
|
||||||
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||||
@@ -761,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// umulhi
|
||||||
|
|
||||||
|
// The result takes the type of `lhs`. Operand 0 is `lhs`, operand 1 is `rhs`.
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
  : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next)
{
  set_operand(0, lhs);
  set_operand(1, rhs);
}
|
||||||
|
|
||||||
|
// Factory: heap-allocates a umulhi_inst over `lhs` and `rhs` and returns it.
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  umulhi_inst *created = new umulhi_inst(lhs, rhs, name, next);
  return created;
}
|
||||||
|
|
||||||
|
|
||||||
// exp
|
// exp
|
||||||
|
|
||||||
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||||
@@ -877,7 +907,7 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
|||||||
make_range *make_range::create(constant_int *first, constant_int *last) {
|
make_range *make_range::create(constant_int *first, constant_int *last) {
|
||||||
assert(first->get_type()->is_integer_ty());
|
assert(first->get_type()->is_integer_ty());
|
||||||
assert(first->get_type() == last->get_type());
|
assert(first->get_type() == last->get_type());
|
||||||
assert(((constant_int*)first)->get_value() == 0);
|
// assert(((constant_int*)first)->get_value() == 0);
|
||||||
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
|
||||||
return new make_range(ty, first, last);
|
return new make_range(ty, first, last);
|
||||||
}
|
}
|
||||||
|
@@ -313,6 +313,7 @@ void init_triton_frontend(py::module &&m) {
|
|||||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||||
// type manipulation
|
// type manipulation
|
||||||
|
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||||
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||||
@@ -340,6 +341,7 @@ void init_triton_frontend(py::module &&m) {
|
|||||||
m.def("max", &ir::dispatch::max, ret::reference);
|
m.def("max", &ir::dispatch::max, ret::reference);
|
||||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||||
// math
|
// math
|
||||||
|
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||||
m.def("log", &ir::dispatch::log, ret::reference);
|
m.def("log", &ir::dispatch::log, ret::reference);
|
||||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||||
|
@@ -346,6 +346,18 @@ def broadcast_to(input, shape, _builder=None):
|
|||||||
"""
|
"""
|
||||||
return frontend.broadcast_to(input, shape, _builder)
|
return frontend.broadcast_to(input, shape, _builder)
|
||||||
|
|
||||||
|
@builtin
def cat(input, other, _builder=None):
    """
    Concatenate the given blocks: the values of :code:`input` come first,
    followed by the values of :code:`other`.

    :param input: The first input block.
    :type input: Block
    :param other: The second input block.
    :type other: Block
    """
    joined = frontend.cat(input, other, _builder)
    return joined
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def reshape(input, shape, _builder=None):
|
def reshape(input, shape, _builder=None):
|
||||||
@@ -524,6 +536,10 @@ def where(condition, x, y, _builder=None):
|
|||||||
# Math
|
# Math
|
||||||
# -----------------------
|
# -----------------------
|
||||||
|
|
||||||
|
@builtin
def umulhi(x, y, _builder=None):
    """
    Forwards :code:`x` and :code:`y` to the frontend :code:`umulhi` operation.

    Code generation for this operation emits the 32-bit
    :code:`mul.hi.u32` instruction on the operands.

    :param x: The first operand.
    :param y: The second operand.
    """
    high = frontend.umulhi(x, y, _builder)
    return high
|
||||||
|
|
||||||
def _add_math_1arg_docstr(name):
|
def _add_math_1arg_docstr(name):
|
||||||
|
|
||||||
def _decorator(func):
|
def _decorator(func):
|
||||||
@@ -543,7 +559,6 @@ def _add_math_1arg_docstr(name):
|
|||||||
def exp(x, _builder=None):
|
def exp(x, _builder=None):
|
||||||
return frontend.exp(x, _builder)
|
return frontend.exp(x, _builder)
|
||||||
|
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
@_add_math_1arg_docstr("natural logarithm")
|
@_add_math_1arg_docstr("natural logarithm")
|
||||||
def log(x, _builder=None):
|
def log(x, _builder=None):
|
||||||
|
@@ -31,42 +31,26 @@ def PHILOX_ROUND_B():
|
|||||||
# 0xCD9E8D57
|
# 0xCD9E8D57
|
||||||
return -845247145
|
return -845247145
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def hacky_to_uint64(x):
|
def hacky_to_uint64(x):
|
||||||
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def multiply_low_high(a, b):
|
|
||||||
return (
|
|
||||||
a * b,
|
|
||||||
((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def single_round(c0, c1, c2, c3, k0, k1):
|
def single_round(c0, c1, c2, c3, k0, k1):
|
||||||
A = PHILOX_ROUND_A()
|
A = PHILOX_ROUND_A()
|
||||||
B = PHILOX_ROUND_B()
|
B = PHILOX_ROUND_B()
|
||||||
lo0, hi0 = multiply_low_high(A, c0)
|
_c0, _c2 = c0, c2
|
||||||
lo1, hi1 = multiply_low_high(B, c2)
|
c0 = tl.umulhi(B, _c2) ^ c1 ^ k0
|
||||||
|
c2 = tl.umulhi(A, _c0) ^ c3 ^ k1
|
||||||
return (
|
c1 = B * _c2
|
||||||
hi1 ^ c1 ^ k0,
|
c3 = A * _c0
|
||||||
lo1,
|
return c0, c1, c2, c3
|
||||||
hi0 ^ c3 ^ k1,
|
|
||||||
lo0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def raise_key(k0, k1):
|
def raise_key(k0, k1):
|
||||||
return (
|
return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B())
|
||||||
k0 + PHILOX_KEY_A(),
|
|
||||||
k1 + PHILOX_KEY_B(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def philox_f(c0, c1, c2, c3, k0, k1):
|
def philox_f(c0, c1, c2, c3, k0, k1):
|
||||||
@@ -125,7 +109,7 @@ def randint4x(seed, offset):
|
|||||||
:param seed: The seed for generating random numbers.
|
:param seed: The seed for generating random numbers.
|
||||||
:param offset: The offsets to generate random numbers for.
|
:param offset: The offsets to generate random numbers for.
|
||||||
"""
|
"""
|
||||||
z = 0
|
z = offset*0 #FIXME: just 0 doesn't work. Likely some error with broadcasting
|
||||||
seed = hacky_to_uint64(seed) # uint will solve this
|
seed = hacky_to_uint64(seed) # uint will solve this
|
||||||
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32)
|
||||||
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
seed_lo = (seed & 0xffffffff).to(tl.int32)
|
||||||
|
Reference in New Issue
Block a user