API: added diag(matrix)
This commit is contained in:
@@ -17,6 +17,8 @@ namespace isaac
|
||||
namespace detail
|
||||
{
|
||||
inline int_t max(size4 const & s) { return std::max(s[0], s[1]); }
|
||||
inline int_t min(size4 const & s) { return std::min(s[0], s[1]); }
|
||||
|
||||
}
|
||||
|
||||
/*--- Constructors ---*/
|
||||
@@ -658,6 +660,10 @@ math_expression cast(math_expression const & x, numeric_type dtype)
|
||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
isaac::math_expression diag(array const & x, int offset)
|
||||
{ return math_expression(x, value_scalar(offset), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_DIAG_TYPE), x.context(), x.dtype(), size4(detail::min(x.shape()), 1, 1, 1)); }
|
||||
|
||||
|
||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
||||
|
||||
|
Reference in New Issue
Block a user