Kernels: merged start1, start2 and stride1, stride2 into start and stride for matrices
This commit is contained in:
@@ -180,10 +180,8 @@ public:
|
||||
mapped_array(std::string const & scalartype, unsigned int id, char type);
|
||||
private:
|
||||
std::string ld_;
|
||||
std::string start1_;
|
||||
std::string start2_;
|
||||
std::string stride1_;
|
||||
std::string stride2_;
|
||||
std::string start_;
|
||||
std::string stride_;
|
||||
char type_;
|
||||
};
|
||||
|
||||
|
@@ -210,21 +210,18 @@ mapped_array::mapped_array(std::string const & scalartype, unsigned int id, char
|
||||
{
|
||||
if(type_ == 's')
|
||||
{
|
||||
register_attribute(start1_, "#start", name_ + "_start");
|
||||
register_attribute(start_, "#start", name_ + "_start");
|
||||
}
|
||||
else if(type_=='m')
|
||||
{
|
||||
register_attribute(start1_, "#start1", name_ + "_start1");
|
||||
register_attribute(start2_, "#start2", name_ + "_start2");
|
||||
register_attribute(stride1_, "#stride1", name_ + "_stride1");
|
||||
register_attribute(stride2_, "#stride2", name_ + "_stride2");
|
||||
register_attribute(start_, "#start", name_ + "_start");
|
||||
register_attribute(stride_, "#stride", name_ + "_stride");
|
||||
register_attribute(ld_, "#ld", name_ + "_ld");
|
||||
keywords_["#nldstride"] = "#stride2";
|
||||
}
|
||||
else
|
||||
{
|
||||
register_attribute(start1_, "#start", name_ + "_start");
|
||||
register_attribute(stride1_, "#stride", name_ + "_stride");
|
||||
register_attribute(start_, "#start", name_ + "_start");
|
||||
register_attribute(stride_, "#stride", name_ + "_stride");
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -65,9 +65,9 @@ std::string axpy::generate_impl(std::string const & suffix, expressions_tuple co
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
process(stream, PARENT_NODE_TYPE, {{"array1", dtype + " #namereg = #pointer[i*#stride];"},
|
||||
{"matrix_row", "#scalartype #namereg = $VALUE{#row*#stride1, i*#stride2};"},
|
||||
{"matrix_column", "#scalartype #namereg = $VALUE{i*#stride1,#column*#stride2};"},
|
||||
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride1, i*#stride2}:$OFFSET{i*#stride1, (i + #diag_offset)*#stride2}];"}}, expressions, mappings);
|
||||
{"matrix_row", "#scalartype #namereg = $VALUE{#row*#stride, i};"},
|
||||
{"matrix_column", "#scalartype #namereg = $VALUE{i*#stride,#column};"},
|
||||
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i}:$OFFSET{i*#stride, (i + #diag_offset)}];"}}, expressions, mappings);
|
||||
|
||||
|
||||
evaluate(stream, PARENT_NODE_TYPE, {{"array0", "#namereg"}, {"array1", "#namereg"},
|
||||
@@ -79,7 +79,7 @@ std::string axpy::generate_impl(std::string const & suffix, expressions_tuple co
|
||||
process(stream, LHS_NODE_TYPE, {{"array1", "#pointer[i*#stride] = #namereg;"},
|
||||
{"matrix_row", "$VALUE{#row, i} = #namereg;"},
|
||||
{"matrix_column", "$VALUE{i, #column} = #namereg;"},
|
||||
{"matrix_diag", "#diag_offset<0?$VALUE{(i - #diag_offset)*#stride1, i*#stride2}:$VALUE{i*#stride1, (i + #diag_offset)*#stride2} = #namereg;"}}, expressions, mappings);
|
||||
{"matrix_diag", "#diag_offset<0?$VALUE{(i - #diag_offset)*#stride, i}:$VALUE{i*#stride, (i + #diag_offset)} = #namereg;"}}, expressions, mappings);
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
@@ -140,9 +140,9 @@ std::string dot::generate_impl(std::string const & suffix, expressions_tuple con
|
||||
//Fetch vector entry
|
||||
for (const auto & elem : exprs)
|
||||
(elem)->process_recursive(stream, PARENT_NODE_TYPE, {{"array1", append_width("#scalartype",simd_width) + " #namereg = " + vload(simd_width,"#scalartype",i,"#pointer",backend)+";"},
|
||||
{"matrix_row", "#scalartype #namereg = #pointer[$OFFSET{#row*#stride, i*#stride2}];"},
|
||||
{"matrix_column", "#scalartype #namereg = #pointer[$OFFSET{i*#stride,#column*#stride2}];"},
|
||||
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i*#stride2}:$OFFSET{i*#stride, (i + #diag_offset)*#stride2}];"}});
|
||||
{"matrix_row", "#scalartype #namereg = #pointer[$OFFSET{#row*#stride, i}];"},
|
||||
{"matrix_column", "#scalartype #namereg = #pointer[$OFFSET{i*#stride,#column}];"},
|
||||
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i}:$OFFSET{i*#stride, (i + #diag_offset)}];"}});
|
||||
|
||||
//Update accumulators
|
||||
std::vector<std::string> str(simd_width);
|
||||
|
@@ -88,8 +88,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array0", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"array1", "#pointer += #start;"},
|
||||
{"array2", "#pointer += #start1 + #start2*#ld; "
|
||||
"#ld *= #nldstride; "}}, expressions, mappings);
|
||||
{"array2", "#pointer += #start;"}}, expressions, mappings);
|
||||
|
||||
unsigned int local_size_0_ld = p_.local_size_0;
|
||||
std::string local_size_0_ld_str = to_string(local_size_0_ld);
|
||||
@@ -128,12 +127,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
|
||||
std::map<std::string, std::string> accessors;
|
||||
if(dot_type_==REDUCE_COLUMNS)
|
||||
{
|
||||
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
|
||||
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride", "#pointer + r*#ld", backend)+";";
|
||||
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
|
||||
}
|
||||
else
|
||||
{
|
||||
accessors["array2"] = "#scalartype #namereg = #pointer[r*#stride1 + c*#ld];";
|
||||
accessors["array2"] = "#scalartype #namereg = #pointer[r*#stride + c*#ld];";
|
||||
accessors["repeat"] = "#scalartype #namereg = $VALUE{(r%#tuplearg0)*#stride, (c%#tuplearg1)*#stride};";
|
||||
}
|
||||
e->process_recursive(stream, PARENT_NODE_TYPE, accessors);
|
||||
@@ -234,8 +233,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
|
||||
process(stream, PARENT_NODE_TYPE,
|
||||
{{"array0", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"array1", "#pointer += #start;"},
|
||||
{"array2", "#pointer += #start1 + #start2*#ld; "
|
||||
"#ld *= #nldstride; "}}, expressions, mappings);
|
||||
{"array2", "#pointer += #start; "}}, expressions, mappings);
|
||||
|
||||
for (const auto & e : dots)
|
||||
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;
|
||||
|
@@ -51,7 +51,7 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
|
||||
|
||||
process(stream, PARENT_NODE_TYPE, { {"array0", "#scalartype #namereg = #pointer[#start];"},
|
||||
{"array1", "#pointer += #start;"},
|
||||
{"array2", "#pointer = &$VALUE{#start1, #start2};"}}
|
||||
{"array2", "#pointer += #start;"}}
|
||||
, expressions, mappings);
|
||||
|
||||
fetching_loop_info(p_.fetching_policy, "M", stream, init0, upper_bound0, inc0, GlobalIdx0(backend).get(), GlobalSize0(backend).get(), device);
|
||||
@@ -63,9 +63,9 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
|
||||
stream << "{" << std::endl;
|
||||
stream.inc_tab();
|
||||
|
||||
process(stream, PARENT_NODE_TYPE, { {"array2", data_type + " #namereg = $VALUE{i*#stride1,j*#stride2};"},
|
||||
{"vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};"},
|
||||
{"repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};"},
|
||||
process(stream, PARENT_NODE_TYPE, { {"array2", data_type + " #namereg = $VALUE{i*#stride,j};"},
|
||||
{"vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride, j*#stride)};"},
|
||||
{"repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride, (j%#tuplearg1)};"},
|
||||
{"outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});"} }
|
||||
, expressions, mappings);
|
||||
|
||||
@@ -78,7 +78,7 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
|
||||
{"host_scalar", p_.simd_width==1?"#name": InitPrefix(backend, data_type).get() + "(#name)"}}
|
||||
, expressions, mappings);
|
||||
|
||||
process(stream, LHS_NODE_TYPE, { {"array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;"} } , expressions, mappings);
|
||||
process(stream, LHS_NODE_TYPE, { {"array2", "$VALUE{i*#stride,j} = #namereg;"} } , expressions, mappings);
|
||||
|
||||
stream.dec_tab();
|
||||
stream << "}" << std::endl;
|
||||
|
@@ -22,7 +22,7 @@ inline std::string generate_arguments(std::string const & data_type, driver::Dev
|
||||
process(stream, PARENT_NODE_TYPE, { {"array0", kwglobal + " #scalartype* #pointer, " + _size_t + " #start,"},
|
||||
{"host_scalar", "#scalartype #name,"},
|
||||
{"array1", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #start, " + _size_t + " #stride,"},
|
||||
{"array2", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #ld, " + _size_t + " #start1, " + _size_t + " #start2, " + _size_t + " #stride1, " + _size_t + " #stride2,"},
|
||||
{"array2", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #ld, " + _size_t + " #start, " + _size_t + " #stride, "},
|
||||
{"tuple4", "#scalartype #name0, #scalartype #name1, #scalartype #name2, #scalartype #name3,"}}
|
||||
, expressions, mappings);
|
||||
|
||||
@@ -81,11 +81,9 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel_.setSizeArg(current_arg_++, a->ld());
|
||||
kernel_.setSizeArg(current_arg_++, a->start()[0]);
|
||||
kernel_.setSizeArg(current_arg_++, a->start()[1]);
|
||||
kernel_.setSizeArg(current_arg_++, a->ld()*a->stride()[1]);
|
||||
kernel_.setSizeArg(current_arg_++, a->start()[0] + a->start()[1]*a->ld());
|
||||
kernel_.setSizeArg(current_arg_++, a->stride()[0]);
|
||||
kernel_.setSizeArg(current_arg_++, a->stride()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user