API: added diag(matrix)

This commit is contained in:
Philippe Tillet
2015-10-04 17:05:06 -04:00
parent 740f5def49
commit 07e7bd862c
3 changed files with 12 additions and 2 deletions

View File

@@ -290,6 +290,10 @@ ISAACAPI math_expression eye(int_t, int_t, isaac::numeric_type, driver::Context
ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default()); ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
ISAACAPI math_expression reshape(array const &, int_t, int_t); ISAACAPI math_expression reshape(array const &, int_t, int_t);
//diag
ISAACAPI math_expression diag(array const &, int offset = 0);
ISAACAPI math_expression diag(math_expression const &, int offset = 0);
//Row //Row
ISAACAPI math_expression row(array const &, value_scalar const &); ISAACAPI math_expression row(array const &, value_scalar const &);
ISAACAPI math_expression row(array const &, for_idx_t const &); ISAACAPI math_expression row(array const &, for_idx_t const &);

View File

@@ -17,6 +17,8 @@ namespace isaac
namespace detail namespace detail
{ {
inline int_t max(size4 const & s) { return std::max(s[0], s[1]); } inline int_t max(size4 const & s) { return std::max(s[0], s[1]); }
inline int_t min(size4 const & s) { return std::min(s[0], s[1]); }
} }
/*--- Constructors ---*/ /*--- Constructors ---*/
@@ -658,6 +660,10 @@ math_expression cast(math_expression const & x, numeric_type dtype)
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx) isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); } { return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
isaac::math_expression diag(array const & x, int offset)
{ return math_expression(x, value_scalar(offset), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_DIAG_TYPE), x.context(), x.dtype(), size4(detail::min(x.shape()), 1, 1, 1)); }
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx) isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); } { return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }

View File

@@ -118,7 +118,7 @@ std::string axpy::generate_impl(std::string const & suffix, math_expression cons
std::string array_access = "#scalartype #namereg = #pointer[#index];"; std::string array_access = "#scalartype #namereg = #pointer[#index];";
std::string matrix_row = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i", "#pointer + #row*#stride", "#ld", backend, false) + ";"; std::string matrix_row = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i", "#pointer + #row*#stride", "#ld", backend, false) + ";";
std::string matrix_column = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*#stride", "#pointer + #column*#ld", "#stride", backend, false) + ";"; std::string matrix_column = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*#stride", "#pointer + #column*#ld", "#stride", backend, false) + ";";
std::string matrix_diag = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*(#ld + #stride)", "#pointer + (#diag_offset<0)?-#diag_offset:(#diag_offset*#ld)", "#ld + #stride", backend, false) + ";"; std::string matrix_diag = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*(#ld + #stride)", "#pointer + ((#diag_offset<0)?-#diag_offset:(#diag_offset*#ld))", "#ld + #stride", backend, false) + ";";
process(stream, RHS_NODE_TYPE, {{"array1", array1}, {"matrix_row", matrix_row}, {"matrix_column", matrix_column}, process(stream, RHS_NODE_TYPE, {{"array1", array1}, {"matrix_row", matrix_row}, {"matrix_column", matrix_column},
{"matrix_diag", matrix_diag}, {"array_access", array_access}}, expressions, idx, mappings, processed); {"matrix_diag", matrix_diag}, {"array_access", array_access}}, expressions, idx, mappings, processed);
} }
@@ -165,7 +165,7 @@ std::string axpy::generate_impl(std::string const & suffix, math_expression cons
stream.dec_tab(); stream.dec_tab();
stream << "}" << std::endl; stream << "}" << std::endl;
// std::cout << stream.str() << std::endl; std::cout << stream.str() << std::endl;
return stream.str(); return stream.str();
} }