[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)
This commit is contained in:
@@ -116,7 +116,8 @@ void axes::update_graph(ir::instruction *i) {
|
||||
switch (i->get_id()) {
|
||||
case ir::INST_REDUCE: return update_graph_reduce(i);
|
||||
case ir::INST_RESHAPE: return update_graph_reshape(i);
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);;
|
||||
case ir::INST_SPLAT: return update_graph_no_edge(i);
|
||||
case ir::INST_CAT: return update_graph_elementwise(i, true);
|
||||
case ir::INST_TRANS: return update_graph_trans(i);
|
||||
case ir::INST_BROADCAST: return update_graph_broadcast(i);
|
||||
case ir::INST_DOT: return update_graph_dot(i);
|
||||
|
@@ -499,6 +499,7 @@ void layouts::run(ir::module &mod) {
|
||||
make_graph(i);
|
||||
});
|
||||
|
||||
|
||||
// connected components
|
||||
graph_.connected_components(&values_, &groups_);
|
||||
|
||||
|
@@ -774,6 +774,22 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) {
|
||||
visit_store_inst(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `cat`
|
||||
*/
|
||||
void generator::visit_cat_inst(ir::cat_inst* x) {
|
||||
auto idxs = idxs_.at(x);
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
int i = 0;
|
||||
for(size_t j = 0; j < idxs_.at(lhs).size(); j ++)
|
||||
vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]];
|
||||
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
|
||||
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `reshape`
|
||||
@@ -861,6 +877,20 @@ void generator::visit_cos_inst(ir::cos_inst* x){
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `umulhi`
|
||||
*/
|
||||
void generator::visit_umulhi_inst(ir::umulhi_inst* x){
|
||||
std::vector<llvm::Type*> tys = {i32_ty, i32_ty};
|
||||
FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false);
|
||||
InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false);
|
||||
for(auto idx: idxs_.at(x)){
|
||||
Value* lhs = vals_[x->get_operand(0)][idx];
|
||||
Value* rhs = vals_[x->get_operand(1)][idx];
|
||||
vals_[x][idx] = call(umulhi, std::vector<llvm::Value*>{lhs, rhs});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `sin`
|
||||
*/
|
||||
|
@@ -11,6 +11,8 @@ namespace transform{
|
||||
|
||||
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
|
||||
std::set<ir::value*>& seen) {
|
||||
if (dynamic_cast<ir::phi_node*>(root))
|
||||
return root;
|
||||
if(!seen.insert(root).second)
|
||||
return root;
|
||||
if(!root->get_type()->is_block_ty())
|
||||
|
@@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
|
||||
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o";
|
||||
err = system(cmd.c_str());
|
||||
CUmodule ret;
|
||||
std::ifstream _cubin(_fbin, std::ios::binary );
|
||||
|
@@ -297,6 +297,10 @@ value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) {
|
||||
return insert(reshape_inst::create(arg, shapes));
|
||||
}
|
||||
|
||||
// Creates a `cat` instruction over `lhs` and `rhs` and hands it to insert().
value *builder::create_cat(value *lhs, value *rhs) {
  auto *cat = cat_inst::create(lhs, rhs);
  return insert(cat);
}
|
||||
|
||||
// Creates a `splat` instruction of `arg` with the given block shapes and
// hands it to insert().
value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) {
  auto *splat = splat_inst::create(arg, shapes);
  return insert(splat);
}
|
||||
@@ -369,6 +373,9 @@ value *builder::create_select(value *pred, value *if_value, value *else_value){
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Creates a `umulhi` instruction over `lhs` and `rhs` and hands it to insert().
value *builder::create_umulhi(value *lhs, value *rhs) {
  auto *mul_hi = umulhi_inst::create(lhs, rhs);
  return insert(mul_hi);
}
|
||||
|
||||
value *builder::create_copy_to_shared(value *arg) {
|
||||
return insert(copy_to_shared_inst::create(arg));
|
||||
|
@@ -368,6 +368,10 @@ ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *b
|
||||
return builder->create_reshape(input, dst_shape);
|
||||
}
|
||||
|
||||
// Front-end entry point for `cat`: forwards both operands to the builder.
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::value *joined = builder->create_cat(lhs, rhs);
  return joined;
}
|
||||
|
||||
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
|
||||
if (!input->get_type()->is_block_ty())
|
||||
return builder->create_splat(input, shape);
|
||||
@@ -715,6 +719,11 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Front-end entry point for `umulhi`: runs the binary-operand type checking
// on `x` and `y`, then emits the instruction through the builder's
// create_umulhi(), like the other dispatch helpers that call builder->create_*.
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
  binary_op_type_checking(x, y, builder);
  return builder->create_umulhi(x, y);
}
|
||||
|
||||
// Front-end entry point for `exp`: forwards the operand to the builder.
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
  ir::value *result = builder->create_exp(x);
  return result;
}
|
||||
|
@@ -522,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
|
||||
: instruction(block_type::get(x->get_type()->get_scalar_ty(),
|
||||
{x->get_type()->get_block_shapes()[0] +
|
||||
y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
|
||||
set_operand(0, x);
|
||||
set_operand(1, y);
|
||||
}
|
||||
|
||||
// Factory: allocates a new cat_inst with the given operands, name and
// successor instruction.
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  cat_inst *result = new cat_inst(lhs, rhs, name, next);
  return result;
}
|
||||
|
||||
// retile
|
||||
|
||||
// Unary instruction whose result is a block of `arg`'s scalar type with the
// requested `shapes`; `id` selects the concrete instruction kind.
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
                         const std::string &name, instruction *next)
  : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes),
               id, arg, name, next) {
}
|
||||
|
||||
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
|
||||
@@ -761,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
}
|
||||
|
||||
|
||||
// umulhi
|
||||
|
||||
// Two-operand builtin tagged INST_UMULHI; the result type is taken from `lhs`.
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
  : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
  value *operands[] = {lhs, rhs};
  for(unsigned i = 0; i < 2; ++i)
    set_operand(i, operands[i]);
}
|
||||
|
||||
// Factory: allocates a new umulhi_inst with the given operands, name and
// successor instruction.
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
  umulhi_inst *result = new umulhi_inst(lhs, rhs, name, next);
  return result;
}
|
||||
|
||||
|
||||
// exp
|
||||
|
||||
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
|
||||
@@ -877,7 +907,7 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
||||
// Creates a `make_range` whose block type has (last - first) elements of the
// integer type shared by `first` and `last`.
make_range *make_range::create(constant_int *first, constant_int *last) {
  assert(first->get_type()->is_integer_ty());
  assert(first->get_type() == last->get_type());
  // `first` is not required to be zero: the extent below is computed
  // relative to it.
  const unsigned extent = static_cast<unsigned>(last->get_value()) -
                          static_cast<unsigned>(first->get_value());
  type *ty = block_type::get(first->get_type(), {extent});
  return new make_range(ty, first, last);
}
|
||||
|
Reference in New Issue
Block a user