[GENERAL] Merged einsum feature branch. Various features, performance improvements and bugfixes:

* Added preliminary support for extended Einstein summation in PyTriton
* Significant performance improvement on FP32 kernels containing matrix
multiplication
* Added re-coalescing pass for FP16 kernels containing matrix
multiplication
* Various bugfixes
This commit is contained in:
Philippe Tillet
2020-01-16 12:09:50 -05:00
parent 50a52df489
commit f278d9741a
49 changed files with 1923 additions and 994 deletions

View File

@@ -48,6 +48,9 @@ value *builder::get_int32(unsigned val) {
return constant_int::get(type::get_int32_ty(ctx_), val);
}
// Forwards to type::get_void_ty() using this builder's context (ctx_).
type *builder::get_void_ty() {
  return type::get_void_ty(ctx_);
}
// Forwards to type::get_int1_ty() using this builder's context (ctx_).
type *builder::get_int1_ty() {
  return type::get_int1_ty(ctx_);
}
@@ -132,19 +135,12 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
}
// Generates the definition of builder::create_<SUFFIX>(arg, name).
// The generated function builds an instruction with
// binary_operator::create_<SUFFIX>(arg), hands it to insert() together with
// `name`, and returns insert()'s result.
#define DEFINE_UNARY_FLOAT(SUFFIX)\
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
}
// Binary floating-point builders, one per binary_op_t opcode.
// NOTE(review): DEFINE_BINARY_FLOAT is defined outside this excerpt; each
// line presumably emits builder::create_<suffix>(lhs, rhs, name) -- confirm.
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
// Unary floating-point builder: expands to builder::create_fneg, which
// forwards to binary_operator::create_fneg.
DEFINE_UNARY_FLOAT(fneg)
//===----------------------------------------------------------------------===//
@@ -171,10 +167,7 @@ value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
}
// Generates the definition of builder::create_<SUFFIX>(arg, name).
// The generated function builds an instruction with
// binary_operator::create_<SUFFIX>(arg), hands it to insert() together with
// `name`, and returns insert()'s result.
#define DEFINE_UNARY_INT(SUFFIX)\
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
}
// Binary integer builders.
// NOTE(review): DEFINE_NOWRAP_BINARY is defined outside this excerpt; it
// presumably emits builder::create_mul for binary_op_t::Mul -- confirm.
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
@@ -190,9 +183,6 @@ DEFINE_BINARY_INT(urem, binary_op_t::URem)
// Bitwise binary integer builders.
// NOTE(review): DEFINE_BINARY_INT is defined outside this excerpt; confirm
// the generated function names and signatures against its definition.
DEFINE_BINARY_INT(and, binary_op_t::And)
DEFINE_BINARY_INT(or,  binary_op_t::Or)
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
// Unary integer builders: expand to builder::create_neg and
// builder::create_not, which forward to binary_operator::create_neg and
// binary_operator::create_not respectively.
DEFINE_UNARY_INT(neg)
DEFINE_UNARY_INT(not)
//===----------------------------------------------------------------------===//