API: added diag(matrix)

This commit is contained in:
Philippe Tillet
2015-10-04 17:05:06 -04:00
parent 740f5def49
commit 07e7bd862c
3 changed files with 12 additions and 2 deletions

View File

@@ -17,6 +17,8 @@ namespace isaac
namespace detail
{
inline int_t max(size4 const & s) { return std::max(s[0], s[1]); }
inline int_t min(size4 const & s) { return std::min(s[0], s[1]); }
}
/*--- Constructors ---*/
@@ -658,6 +660,10 @@ math_expression cast(math_expression const & x, numeric_type dtype)
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
isaac::math_expression diag(array const & x, int offset)
{ return math_expression(x, value_scalar(offset), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_DIAG_TYPE), x.context(), x.dtype(), size4(detail::min(x.shape()), 1, 1, 1)); }
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }