API: added diag(matrix)
This commit is contained in:
@@ -290,6 +290,10 @@ ISAACAPI math_expression eye(int_t, int_t, isaac::numeric_type, driver::Context
|
|||||||
ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
|
ISAACAPI math_expression zeros(int_t M, int_t N, numeric_type dtype, driver::Context const & context = driver::backend::contexts::get_default());
|
||||||
ISAACAPI math_expression reshape(array const &, int_t, int_t);
|
ISAACAPI math_expression reshape(array const &, int_t, int_t);
|
||||||
|
|
||||||
|
//diag
|
||||||
|
ISAACAPI math_expression diag(array const &, int offset = 0);
|
||||||
|
ISAACAPI math_expression diag(math_expression const &, int offset = 0);
|
||||||
|
|
||||||
//Row
|
//Row
|
||||||
ISAACAPI math_expression row(array const &, value_scalar const &);
|
ISAACAPI math_expression row(array const &, value_scalar const &);
|
||||||
ISAACAPI math_expression row(array const &, for_idx_t const &);
|
ISAACAPI math_expression row(array const &, for_idx_t const &);
|
||||||
|
@@ -17,6 +17,8 @@ namespace isaac
|
|||||||
namespace detail
|
namespace detail
|
||||||
{
|
{
|
||||||
inline int_t max(size4 const & s) { return std::max(s[0], s[1]); }
|
inline int_t max(size4 const & s) { return std::max(s[0], s[1]); }
|
||||||
|
inline int_t min(size4 const & s) { return std::min(s[0], s[1]); }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*--- Constructors ---*/
|
/*--- Constructors ---*/
|
||||||
@@ -658,6 +660,10 @@ math_expression cast(math_expression const & x, numeric_type dtype)
|
|||||||
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
isaac::math_expression eye(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||||
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
{ return math_expression(value_scalar(1), value_scalar(0), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_VDIAG_TYPE), ctx, dtype, size4(M, N)); }
|
||||||
|
|
||||||
|
isaac::math_expression diag(array const & x, int offset)
|
||||||
|
{ return math_expression(x, value_scalar(offset), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_MATRIX_DIAG_TYPE), x.context(), x.dtype(), size4(detail::min(x.shape()), 1, 1, 1)); }
|
||||||
|
|
||||||
|
|
||||||
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
isaac::math_expression zeros(int_t M, int_t N, isaac::numeric_type dtype, driver::Context const & ctx)
|
||||||
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
{ return math_expression(value_scalar(0, dtype), invalid_node(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), ctx, dtype, size4(M, N)); }
|
||||||
|
|
||||||
|
@@ -118,7 +118,7 @@ std::string axpy::generate_impl(std::string const & suffix, math_expression cons
|
|||||||
std::string array_access = "#scalartype #namereg = #pointer[#index];";
|
std::string array_access = "#scalartype #namereg = #pointer[#index];";
|
||||||
std::string matrix_row = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i", "#pointer + #row*#stride", "#ld", backend, false) + ";";
|
std::string matrix_row = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i", "#pointer + #row*#stride", "#ld", backend, false) + ";";
|
||||||
std::string matrix_column = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*#stride", "#pointer + #column*#ld", "#stride", backend, false) + ";";
|
std::string matrix_column = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*#stride", "#pointer + #column*#ld", "#stride", backend, false) + ";";
|
||||||
std::string matrix_diag = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*(#ld + #stride)", "#pointer + (#diag_offset<0)?-#diag_offset:(#diag_offset*#ld)", "#ld + #stride", backend, false) + ";";
|
std::string matrix_diag = dtype + " #namereg = " + vload(p_.simd_width, "#scalartype", "i*(#ld + #stride)", "#pointer + ((#diag_offset<0)?-#diag_offset:(#diag_offset*#ld))", "#ld + #stride", backend, false) + ";";
|
||||||
process(stream, RHS_NODE_TYPE, {{"array1", array1}, {"matrix_row", matrix_row}, {"matrix_column", matrix_column},
|
process(stream, RHS_NODE_TYPE, {{"array1", array1}, {"matrix_row", matrix_row}, {"matrix_column", matrix_column},
|
||||||
{"matrix_diag", matrix_diag}, {"array_access", array_access}}, expressions, idx, mappings, processed);
|
{"matrix_diag", matrix_diag}, {"array_access", array_access}}, expressions, idx, mappings, processed);
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ std::string axpy::generate_impl(std::string const & suffix, math_expression cons
|
|||||||
stream.dec_tab();
|
stream.dec_tab();
|
||||||
stream << "}" << std::endl;
|
stream << "}" << std::endl;
|
||||||
|
|
||||||
// std::cout << stream.str() << std::endl;
|
std::cout << stream.str() << std::endl;
|
||||||
|
|
||||||
return stream.str();
|
return stream.str();
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user