[CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356)

This commit is contained in:
Philippe Tillet
2021-10-24 02:30:46 -07:00
committed by GitHub
parent 858dec8372
commit 5ce1b726dc
17 changed files with 149 additions and 60 deletions

View File

@@ -522,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
// retile_inst classes
//===----------------------------------------------------------------------===//
// cat
cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next)
: instruction(block_type::get(x->get_type()->get_scalar_ty(),
{x->get_type()->get_block_shapes()[0] +
y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) {
set_operand(0, x);
set_operand(1, y);
}
instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
return new cat_inst(lhs, rhs, name, next);
}
// retile
retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes,
const std::string &name, instruction *next)
: unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
// reshape
instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes,
@@ -761,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
}
// umulhi
umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next)
: builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) {
set_operand(0, lhs);
set_operand(1, rhs);
}
instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) {
return new umulhi_inst(lhs, rhs, name, next);
}
// exp
exp_inst::exp_inst(value *val, const std::string &name, instruction *next)
@@ -877,7 +907,7 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last)
make_range *make_range::create(constant_int *first, constant_int *last) {
assert(first->get_type()->is_integer_ty());
assert(first->get_type() == last->get_type());
assert(((constant_int*)first)->get_value() == 0);
// assert(((constant_int*)first)->get_value() == 0);
type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()});
return new make_range(ty, first, last);
}