[BACKEND] Added Int8 mma (#440)
This commit is contained in:
@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, c) + d -> dot(a, b, c + d)
|
||||
// d + dot(a, b, c) -> dot(a, b, c + d)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(add && add->get_op() == ir::binary_op_t::FAdd) {
|
||||
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
|
||||
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
|
||||
ir::value *lhs = add->get_operand(0);
|
||||
ir::value *rhs = add->get_operand(1);
|
||||
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
|
||||
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
ir::value *other = (dot == lhs) ? rhs : lhs;
|
||||
ir::value *acc = dot->get_operand(2);
|
||||
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
|
||||
ir::constant_fp *_0 = nullptr;
|
||||
ir::constant *_0 = nullptr;
|
||||
if(splat)
|
||||
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
|
||||
if(!(_0 && _0->get_value() == 0.0))
|
||||
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
|
||||
if(!_0)
|
||||
return false;
|
||||
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
|
||||
if (fp_0->get_value() != 0.0)
|
||||
return false;
|
||||
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
|
||||
if (int_0->get_value() != 0)
|
||||
return false;
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
builder.set_insert_point(add);
|
||||
|
Reference in New Issue
Block a user