diff --git a/include/isaac/jit/generation/base.h b/include/isaac/jit/generation/base.h index 3e846fa54..1e26222ee 100644 --- a/include/isaac/jit/generation/base.h +++ b/include/isaac/jit/generation/base.h @@ -73,15 +73,6 @@ static const int TEMPLATE_BLOCK_SIZE_TOO_LARGE = -20; class base { -public: - struct parameters_type - { - parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels); - uint32_t vwidth; - uint32_t ls0; - uint32_t ls1; - uint32_t nkernels; - }; private: virtual std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const = 0; public: @@ -97,20 +88,20 @@ public: }; -template class base_impl : public base { private: virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const; public: - typedef ParametersType parameters_type; - base_impl(parameters_type const & parameters); + base_impl(uint32_t _vwidth, int_t _ls0, int_t _ls1); uint32_t ls0() const; uint32_t ls1() const; /** @brief returns whether or not the profile has undefined behavior on particular device */ int is_invalid(expression_tree const & expressions, driver::Device const & device) const; protected: - parameters_type p_; + uint32_t vwidth_; + uint32_t ls0_; + uint32_t ls1_; }; } diff --git a/include/isaac/jit/generation/elementwise_1d.h b/include/isaac/jit/generation/elementwise_1d.h index 368162ae9..c2c0de462 100644 --- a/include/isaac/jit/generation/elementwise_1d.h +++ b/include/isaac/jit/generation/elementwise_1d.h @@ -29,24 +29,18 @@ namespace isaac namespace templates { -class elementwise_1d_parameters : public base::parameters_type -{ -public: - elementwise_1d_parameters(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch); - uint32_t ng; - fetch_type fetch; -}; - -class elementwise_1d : public base_impl +class elementwise_1d : public base_impl { private: virtual int is_invalid_impl(driver::Device const &, expression_tree const &) const; std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & symbols) const; public: - elementwise_1d(elementwise_1d::parameters_type const & parameters); - elementwise_1d(uint32_t _vwidth, uint32_t _group_size, uint32_t _ng, fetch_type _fetch); + elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch); std::vector input_sizes(expression_tree const & expressions) const; void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); +private: + uint32_t ng_; + fetch_type fetch_; }; } diff --git a/include/isaac/jit/generation/elementwise_2d.h b/include/isaac/jit/generation/elementwise_2d.h index 3acbe51a4..e2dd61ec0 100644 --- a/include/isaac/jit/generation/elementwise_2d.h +++ b/include/isaac/jit/generation/elementwise_2d.h @@ -30,17 +30,7 @@ namespace isaac namespace templates { -class elementwise_2d_parameters : public base::parameters_type -{ -public: - elementwise_2d_parameters(uint32_t _vwidth, uint32_t _ls0, uint32_t _ls1, uint32_t _ng0, uint32_t _ng1, fetch_type _fetch); - - uint32_t ng0; - uint32_t ng1; - fetch_type fetch; -}; - -class elementwise_2d : public base_impl +class elementwise_2d : public base_impl { private: int is_invalid_impl(driver::Device const &, expression_tree const &) const; @@ -50,6 +40,10 @@ public: elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch); std::vector input_sizes(expression_tree const & expressions) const; void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); +private: + uint32_t ng0_; + uint32_t ng1_; + fetch_type fetch_; }; } diff --git a/include/isaac/jit/generation/gemm.h b/include/isaac/jit/generation/gemm.h index da0175822..81733514a 100644 --- a/include/isaac/jit/generation/gemm.h +++ b/include/isaac/jit/generation/gemm.h @@ -31,35 +31,7 @@ namespace isaac namespace templates { -struct gemm_parameters : public base::parameters_type -{ - gemm_parameters(uint32_t vwidth - ,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D - ,uint32_t ms, uint32_t ks, uint32_t ns - ,fetch_type Afetch, fetch_type Bfetch - ,uint32_t lf0, uint32_t lf1); - - uint32_t kL; - uint32_t depth; - - uint32_t mS; - uint32_t kS; - uint32_t nS; - - fetch_type Afetch; - fetch_type Bfetch; - - uint32_t lf0; - uint32_t lf1; - - uint32_t mL; - uint32_t nL; - - bool prefetch; - bool unroll_outer; -}; - -class gemm : public base_impl +class gemm : public base_impl { private: uint32_t temporary_workspace(expression_tree const & expressions) const; @@ -71,10 +43,32 @@ private: value_scalar const &alpha, value_scalar const &beta, driver::Program const & program, std::string const & suffix, runtime::execution_options_type const & options); std::vector infos(expression_tree const & expressions, isaac::symbolic::preset::gemm::args &arguments) const; public: - gemm(gemm::parameters_type const & parameters, char A_trans, char B_trans); + gemm(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D + , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch + , int_t lf0, int_t lf1, char A_trans, char B_trans); std::vector input_sizes(expression_tree const & expressions) const; void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &ctr); private: + //Parameters + uint32_t kL_; + uint32_t depth_; + + uint32_t mS_; + uint32_t kS_; + uint32_t nS_; + + fetch_type Afetch_; + fetch_type Bfetch_; + + uint32_t lf0_; + uint32_t lf1_; + + uint32_t mL_; + uint32_t nL_; + + bool prefetch_; + bool unroll_outer_; + // const char A_trans_; const char B_trans_; expression_type type_; @@ -83,7 +77,7 @@ private: class gemm_nn : public gemm { public: - gemm_nn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D + gemm_nn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1); }; @@ -91,7 +85,7 @@ public: class gemm_tn : public gemm { public: - gemm_tn(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D + gemm_tn(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1); }; @@ -100,7 +94,7 @@ public: class gemm_nt : public gemm { public: - gemm_nt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D + gemm_nt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1); }; @@ -109,7 +103,7 @@ public: class gemm_tt : public gemm { public: - gemm_tt(uint32_t simd, int_t ls0, int_t KL, int_t ls1, int_t D + gemm_tt(uint32_t vwidth, int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns, fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1); }; diff --git a/include/isaac/jit/generation/reduce_1d.h b/include/isaac/jit/generation/reduce_1d.h index 139ead036..f914cdb11 100644 --- a/include/isaac/jit/generation/reduce_1d.h +++ b/include/isaac/jit/generation/reduce_1d.h @@ -29,16 +29,7 @@ namespace isaac namespace templates { -struct reduce_1d_parameters : public base::parameters_type -{ - reduce_1d_parameters(uint32_t _vwidth, - uint32_t _group_size, uint32_t _ng, - fetch_type _fetch); - uint32_t ng; - fetch_type fetch; -}; - -class reduce_1d : public base_impl +class reduce_1d : public base_impl { private: uint32_t lmem_usage(expression_tree const & expressions) const; @@ -49,11 +40,12 @@ private: std::string generate_impl(std::string const & suffix, expression_tree const & expressions, driver::Device const & device, symbolic::symbols_table const & mapping) const; public: - reduce_1d(reduce_1d::parameters_type const & parameters); - reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch); + reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch); std::vector input_sizes(expression_tree const & expressions) const; void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); private: + uint32_t ng_; + fetch_type fetch_; std::vector< driver::Buffer > tmp_; std::vector< driver::Buffer > tmpidx_; }; diff --git a/include/isaac/jit/generation/reduce_2d.h b/include/isaac/jit/generation/reduce_2d.h index c6a6f7b75..9039c5a38 100644 --- a/include/isaac/jit/generation/reduce_2d.h +++ b/include/isaac/jit/generation/reduce_2d.h @@ -31,21 +31,11 @@ namespace isaac { namespace templates { -struct reduce_2d_parameters : public base::parameters_type -{ - reduce_2d_parameters(uint32_t _vwidth, - uint32_t _ls0, uint32_t _ls1, - uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy); - uint32_t ng0; - uint32_t ng1; - fetch_type fetch_policy; -}; - class reduce_2d : public base_impl { protected: - reduce_2d(reduce_2d::parameters_type const & , operation_type_family); + reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, operation_type_family); private: int is_invalid_impl(driver::Device const &, expression_tree const &) const; uint32_t lmem_usage(expression_tree const &) const; @@ -55,21 +45,22 @@ public: virtual std::vector input_sizes(expression_tree const & expressions) const; void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &); private: + uint32_t ng0_; + uint32_t ng1_; + fetch_type fetch_; operation_type_family reduction_type_; }; class reduce_2d_rows : public reduce_2d { public: - reduce_2d_rows(reduce_2d::parameters_type const &); - reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch); + reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch); }; class reduce_2d_cols : public reduce_2d { public: - reduce_2d_cols(reduce_2d::parameters_type const &); - reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, fetch_type fetch); + reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch); }; } diff --git a/lib/jit/generation/base.cpp b/lib/jit/generation/base.cpp index 9d0ce99e4..aa8a0f18c 100644 --- a/lib/jit/generation/base.cpp +++ b/lib/jit/generation/base.cpp @@ -40,9 +40,6 @@ namespace isaac namespace templates { -base::parameters_type::parameters_type(uint32_t _vwidth, int_t _ls0, int_t _ls1, int_t _nkernels) : vwidth(_vwidth), ls0(_ls0), ls1(_ls1), nkernels(_nkernels) -{ } - base::base() {} @@ -69,24 +66,20 @@ std::string base::generate(std::string const & suffix, expression_tree const & return generate_impl(suffix, expression, device, mapping); } -template -int base_impl::is_invalid_impl(driver::Device const &, expression_tree const &) const +int base_impl::is_invalid_impl(driver::Device const &, expression_tree const &) const { return TEMPLATE_VALID; } -template -base_impl::base_impl(parameters_type const & parameters) : base(), p_(parameters) +base_impl::base_impl(uint32_t vwidth, int_t ls0, int_t ls1): vwidth_(vwidth), ls0_(ls0), ls1_(ls1) { } -template -uint32_t base_impl::ls0() const +uint32_t base_impl::ls0() const { return p_.ls0; } -template -uint32_t base_impl::ls1() const +uint32_t base_impl::ls1() const { return p_.ls1; } template -int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const +int base_impl::is_invalid(expression_tree const & expressions, driver::Device const & device) const { //Query device informations size_t lmem_available = device.local_mem_size(); @@ -112,11 +105,5 @@ int base_impl::is_invalid(expression_tree const & expressions, dr return is_invalid_impl(device, expressions); } -template class base_impl; -template class base_impl; -template class base_impl; -template class base_impl; -template class base_impl; - } } diff --git a/lib/jit/generation/elementwise_1d.cpp b/lib/jit/generation/elementwise_1d.cpp index 5f7105c29..12bf51db2 100644 --- a/lib/jit/generation/elementwise_1d.cpp +++ b/lib/jit/generation/elementwise_1d.cpp @@ -36,14 +36,6 @@ namespace isaac namespace templates { -elementwise_1d_parameters::elementwise_1d_parameters(uint32_t _vwidth, - uint32_t _group_size, uint32_t _ng, - fetch_type _fetch) : - base::parameters_type(_vwidth, _group_size, 1, 1), ng(_ng), fetch(_fetch) -{ -} - - int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree const &) const { if (p_.fetch==FETCH_FROM_LOCAL) @@ -118,13 +110,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression return stream.str(); } -elementwise_1d::elementwise_1d(elementwise_1d_parameters const & parameters) : - base_impl(parameters) -{} - -elementwise_1d::elementwise_1d(uint32_t simd, uint32_t ls, uint32_t ng, - fetch_type fetch): - base_impl(elementwise_1d_parameters(simd,ls,ng,fetch)) +elementwise_1d::elementwise_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch): + base_impl(vwidth,ls), ng_(ng), fetch_(fetch) {} diff --git a/lib/jit/generation/elementwise_2d.cpp b/lib/jit/generation/elementwise_2d.cpp index 7d48cf329..9cc0c7958 100644 --- a/lib/jit/generation/elementwise_2d.cpp +++ b/lib/jit/generation/elementwise_2d.cpp @@ -33,13 +33,6 @@ namespace isaac namespace templates { -elementwise_2d_parameters::elementwise_2d_parameters(uint32_t _vwidth, - uint32_t _ls0, uint32_t _ls1, - uint32_t _ng0, uint32_t _ng1, - fetch_type _fetch) : base::parameters_type(_vwidth, _ls0, _ls1, 1), ng0(_ng0), ng1(_ng1), fetch(_fetch){ } - - - int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const { if (p_.vwidth>1) @@ -111,12 +104,9 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression return stream.str(); } -elementwise_2d::elementwise_2d(parameters_type const & parameters) : - base_impl(parameters){ } - -elementwise_2d::elementwise_2d(uint32_t simd, uint32_t ls1, uint32_t ls2, - uint32_t ng1, uint32_t ng2, fetch_type fetch): - base_impl(elementwise_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch)) +elementwise_2d::elementwise_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, + uint32_t ng0, uint32_t ng1, fetch_type fetch): + base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch) {} std::vector elementwise_2d::input_sizes(expression_tree const & expression) const{ diff --git a/lib/jit/generation/gemm.cpp b/lib/jit/generation/gemm.cpp index 82ad73bf5..3aeeea44a 100644 --- a/lib/jit/generation/gemm.cpp +++ b/lib/jit/generation/gemm.cpp @@ -37,18 +37,6 @@ namespace isaac namespace templates { -gemm_parameters::gemm_parameters(uint32_t vwidth - ,uint32_t ls0, uint32_t KL, uint32_t ls1, uint32_t D - ,uint32_t ms, uint32_t ks, uint32_t ns - ,fetch_type Afetch, fetch_type Bfetch - ,uint32_t lf0, uint32_t lf1): base::parameters_type(vwidth, ls0, ls1, 1), - kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch), - lf0(lf0), lf1(lf1), - mL(ms*ls0), nL(ns*ls1) -{ -} - - uint32_t gemm::lmem_usage(expression_tree const & expression) const { uint32_t N = 0; @@ -689,7 +677,13 @@ gemm_parameters::gemm_parameters(uint32_t vwidth return {M, N, K}; } - gemm::gemm(gemm_parameters const & parameters, char A_trans, char B_trans) : base_impl(parameters), A_trans_(A_trans), B_trans_(B_trans) + gemm::gemm(uint32_t vwidth + ,int_t ls0, int_t kL, int_t ls1, int_t D + ,int_t ms, int_t ks, int_t ns + ,fetch_type Afetch , fetch_type Bfetch + ,int_t lf0, int_t lf1, char A_trans, char B_trans) : + base_impl(vwidth, ls0, ls1), kL_(kL), depth_(D), mS_(ms), kS_(ks), nS_(ns), + Afetch_(Afetch), Bfetch_(Bfetch), lf0_(lf0), lf1_(lf1), A_trans_(A_trans), B_trans_(B_trans) { if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN; else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN; @@ -721,40 +715,40 @@ gemm_parameters::gemm_parameters(uint32_t vwidth } // - gemm_nn::gemm_nn(uint32_t simd + gemm_nn::gemm_nn(uint32_t vwidth , int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns , fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1) : - gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'N') + gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'N') { } // - gemm_tn::gemm_tn(uint32_t simd + gemm_tn::gemm_tn(uint32_t vwidth , int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns , fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1) : - gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'N') + gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'N') { } // - gemm_nt::gemm_nt(uint32_t simd + gemm_nt::gemm_nt(uint32_t vwidth , int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns , fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1) : - gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'T') + gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'N', 'T') { } // - gemm_tt::gemm_tt(uint32_t simd + gemm_tt::gemm_tt(uint32_t vwidth , int_t ls0, int_t KL, int_t ls1, int_t D , int_t ms, int_t ks, int_t ns , fetch_type Afetch , fetch_type Bfetch , int_t lf0, int_t lf1) : - gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'T') + gemm(vwidth, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1, 'T', 'T') { } } diff --git a/lib/jit/generation/reduce_1d.cpp b/lib/jit/generation/reduce_1d.cpp index bea6ccf48..3f7b7e12b 100644 --- a/lib/jit/generation/reduce_1d.cpp +++ b/lib/jit/generation/reduce_1d.cpp @@ -35,10 +35,6 @@ namespace isaac { namespace templates { -reduce_1d_parameters::reduce_1d_parameters(uint32_t _vwidth, - uint32_t _group_size, uint32_t _ng, - fetch_type _fetch) : base::parameters_type(_vwidth, _group_size, 1, 2), ng(_ng), fetch(_fetch) -{ } uint32_t reduce_1d::lmem_usage(expression_tree const & x) const { @@ -253,11 +249,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree return stream.str(); } -reduce_1d::reduce_1d(reduce_1d::parameters_type const & parameters) : base_impl(parameters) -{ } - -reduce_1d::reduce_1d(uint32_t simd, uint32_t ls, uint32_t ng, fetch_type fetch): - base_impl(reduce_1d_parameters(simd,ls,ng,fetch)) +reduce_1d::reduce_1d(uint32_t vwidth, uint32_t ls, uint32_t ng, fetch_type fetch): + base_impl(vwidth,ls), ng_(ng), fetch_(fetch) {} std::vector reduce_1d::input_sizes(expression_tree const & x) const diff --git a/lib/jit/generation/reduce_2d.cpp b/lib/jit/generation/reduce_2d.cpp index f251b438e..fd722796a 100644 --- a/lib/jit/generation/reduce_2d.cpp +++ b/lib/jit/generation/reduce_2d.cpp @@ -39,15 +39,9 @@ namespace isaac namespace templates { -reduce_2d_parameters::reduce_2d_parameters(uint32_t _vwidth, - uint32_t _ls0, uint32_t _ls1, - uint32_t _ng0, uint32_t _ng1, fetch_type _fetch_policy): base::parameters_type(_vwidth, _ls0, _ls1, 1), -ng0(_ng0), ng1(_ng1), fetch_policy(_fetch_policy) { } - - int reduce_2d::is_invalid_impl(driver::Device const &, expression_tree const &) const { - if (p_.fetch_policy==FETCH_FROM_LOCAL) + if (p_.fetch==FETCH_FROM_LOCAL) return TEMPLATE_INVALID_FETCHING_POLICY_TYPE; return TEMPLATE_VALID; } @@ -127,7 +121,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree std::ostringstream upper; upper << "(M +" << p_.ls1 - 1 << ")/" << p_.ls1 << "*" << p_.ls1; - element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth) + element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_ROWS)?1:1, "r", upper.str(), "$GLOBAL_IDX_1", "$GLOBAL_SIZE_1", device, [&](uint32_t cwidth) { //Declare Buffers for (symbolic::reduce_2d* rd : reductions) @@ -142,7 +136,7 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree stream << "if (r < M)" << std::endl; stream << "{" << std::endl; stream.inc_tab(); - element_wise_loop_1D(stream, p_.fetch_policy, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth) + element_wise_loop_1D(stream, p_.fetch, (reduction_type_==REDUCE_COLUMNS)?p_.vwidth:1, "c", "N", "$GLOBAL_IDX_0", "$GLOBAL_SIZE_0", device, [&](uint32_t rwidth) { std::string rdtype = append_width("#scalartype", rwidth); std::string cdtype = append_width("#scalartype", cwidth); @@ -282,9 +276,9 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree return stream.str(); } -reduce_2d::reduce_2d(reduce_2d::parameters_type const & parameters, - operation_type_family rtype) : - base_impl(parameters), +reduce_2d::reduce_2d(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, fetch_type fetch, + operation_type_family rtype) : + base_impl(vwidth, ls0, ls1), ng0_(ng0), ng1_(ng1), fetch_(fetch), reduction_type_(rtype){ } std::vector reduce_2d::input_sizes(expression_tree const & tree) const @@ -331,15 +325,11 @@ void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & pr control.execution_options().enqueue(program.context(), kernels[i], global[i], local[i]); } -reduce_2d_rows::reduce_2d_rows(reduce_2d_parameters const & parameters): reduce_2d(parameters, REDUCE_ROWS){} +reduce_2d_rows::reduce_2d_rows(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, + fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_ROWS) {} -reduce_2d_rows::reduce_2d_rows(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, - fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_ROWS) {} - -reduce_2d_cols::reduce_2d_cols(reduce_2d::parameters_type const & parameters): reduce_2d(parameters, REDUCE_COLUMNS){} - -reduce_2d_cols::reduce_2d_cols(uint32_t simd, uint32_t ls1, uint32_t ls2, uint32_t ng1, uint32_t ng2, - fetch_type fetch): reduce_2d(reduce_2d_parameters(simd, ls1, ls2, ng1, ng2, fetch), REDUCE_COLUMNS) {} +reduce_2d_cols::reduce_2d_cols(uint32_t vwidth, uint32_t ls0, uint32_t ls1, uint32_t ng0, uint32_t ng1, + fetch_type fetch): reduce_2d(vwidth, ls0, ls1, ng0, ng1, fetch, REDUCE_COLUMNS) {} }