[GENERAL] Merged einsum feature branch. Various feature, performance
improvements and bugfixes: * Added preliminary support for extended Einstein summation in PyTriton * Significant performance improvement on FP32 kernels containing matrix multiplication * Added re-coalescing pass for FP16 kernels containing matrix multiplication * Various bugfixes
This commit is contained in:
@@ -48,6 +48,9 @@ value *builder::get_int32(unsigned val) {
|
||||
return constant_int::get(type::get_int32_ty(ctx_), val);
|
||||
}
|
||||
|
||||
type *builder::get_void_ty()
|
||||
{ return type::get_void_ty(ctx_); }
|
||||
|
||||
type *builder::get_int1_ty()
|
||||
{ return type::get_int1_ty(ctx_); }
|
||||
|
||||
@@ -132,19 +135,12 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string
|
||||
return insert(binary_operator::create(OPCODE, lhs, rhs), name);\
|
||||
}
|
||||
|
||||
#define DEFINE_UNARY_FLOAT(SUFFIX)\
|
||||
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
|
||||
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
|
||||
}
|
||||
|
||||
// Binary
|
||||
DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul)
|
||||
DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv)
|
||||
DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem)
|
||||
DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd)
|
||||
DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub)
|
||||
// Unary
|
||||
DEFINE_UNARY_FLOAT(fneg)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -171,10 +167,7 @@ value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs,
|
||||
return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\
|
||||
}
|
||||
|
||||
#define DEFINE_UNARY_INT(SUFFIX)\
|
||||
value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\
|
||||
return insert(binary_operator::create_ ## SUFFIX(arg), name);\
|
||||
}
|
||||
|
||||
|
||||
// Binary
|
||||
DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul)
|
||||
@@ -190,9 +183,6 @@ DEFINE_BINARY_INT(urem, binary_op_t::URem)
|
||||
DEFINE_BINARY_INT(and, binary_op_t::And)
|
||||
DEFINE_BINARY_INT(or, binary_op_t::Or)
|
||||
DEFINE_BINARY_INT(xor, binary_op_t::Xor)
|
||||
// Unary
|
||||
DEFINE_UNARY_INT(neg)
|
||||
DEFINE_UNARY_INT(not)
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Reference in New Issue
Block a user