Kernels: merged start1, start2 and stride1, stride2 into start and stride for matrices

This commit is contained in:
Philippe Tillet
2015-08-10 22:45:48 -07:00
parent 5365b1331f
commit 963867574f
7 changed files with 26 additions and 35 deletions

View File

@@ -180,10 +180,8 @@ public:
mapped_array(std::string const & scalartype, unsigned int id, char type);
private:
std::string ld_;
std::string start1_;
std::string start2_;
std::string stride1_;
std::string stride2_;
std::string start_;
std::string stride_;
char type_;
};

View File

@@ -210,21 +210,18 @@ mapped_array::mapped_array(std::string const & scalartype, unsigned int id, char
{
if(type_ == 's')
{
register_attribute(start1_, "#start", name_ + "_start");
register_attribute(start_, "#start", name_ + "_start");
}
else if(type_=='m')
{
register_attribute(start1_, "#start1", name_ + "_start1");
register_attribute(start2_, "#start2", name_ + "_start2");
register_attribute(stride1_, "#stride1", name_ + "_stride1");
register_attribute(stride2_, "#stride2", name_ + "_stride2");
register_attribute(start_, "#start", name_ + "_start");
register_attribute(stride_, "#stride", name_ + "_stride");
register_attribute(ld_, "#ld", name_ + "_ld");
keywords_["#nldstride"] = "#stride2";
}
else
{
register_attribute(start1_, "#start", name_ + "_start");
register_attribute(stride1_, "#stride", name_ + "_stride");
register_attribute(start_, "#start", name_ + "_start");
register_attribute(stride_, "#stride", name_ + "_stride");
}
}

View File

@@ -65,9 +65,9 @@ std::string axpy::generate_impl(std::string const & suffix, expressions_tuple co
stream << "{" << std::endl;
stream.inc_tab();
process(stream, PARENT_NODE_TYPE, {{"array1", dtype + " #namereg = #pointer[i*#stride];"},
{"matrix_row", "#scalartype #namereg = $VALUE{#row*#stride1, i*#stride2};"},
{"matrix_column", "#scalartype #namereg = $VALUE{i*#stride1,#column*#stride2};"},
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride1, i*#stride2}:$OFFSET{i*#stride1, (i + #diag_offset)*#stride2}];"}}, expressions, mappings);
{"matrix_row", "#scalartype #namereg = $VALUE{#row*#stride, i};"},
{"matrix_column", "#scalartype #namereg = $VALUE{i*#stride,#column};"},
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i}:$OFFSET{i*#stride, (i + #diag_offset)}];"}}, expressions, mappings);
evaluate(stream, PARENT_NODE_TYPE, {{"array0", "#namereg"}, {"array1", "#namereg"},
@@ -79,7 +79,7 @@ std::string axpy::generate_impl(std::string const & suffix, expressions_tuple co
process(stream, LHS_NODE_TYPE, {{"array1", "#pointer[i*#stride] = #namereg;"},
{"matrix_row", "$VALUE{#row, i} = #namereg;"},
{"matrix_column", "$VALUE{i, #column} = #namereg;"},
{"matrix_diag", "#diag_offset<0?$VALUE{(i - #diag_offset)*#stride1, i*#stride2}:$VALUE{i*#stride1, (i + #diag_offset)*#stride2} = #namereg;"}}, expressions, mappings);
{"matrix_diag", "#diag_offset<0?$VALUE{(i - #diag_offset)*#stride, i}:$VALUE{i*#stride, (i + #diag_offset)} = #namereg;"}}, expressions, mappings);
stream.dec_tab();
stream << "}" << std::endl;

View File

@@ -140,9 +140,9 @@ std::string dot::generate_impl(std::string const & suffix, expressions_tuple con
//Fetch vector entry
for (const auto & elem : exprs)
(elem)->process_recursive(stream, PARENT_NODE_TYPE, {{"array1", append_width("#scalartype",simd_width) + " #namereg = " + vload(simd_width,"#scalartype",i,"#pointer",backend)+";"},
{"matrix_row", "#scalartype #namereg = #pointer[$OFFSET{#row*#stride, i*#stride2}];"},
{"matrix_column", "#scalartype #namereg = #pointer[$OFFSET{i*#stride,#column*#stride2}];"},
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i*#stride2}:$OFFSET{i*#stride, (i + #diag_offset)*#stride2}];"}});
{"matrix_row", "#scalartype #namereg = #pointer[$OFFSET{#row*#stride, i}];"},
{"matrix_column", "#scalartype #namereg = #pointer[$OFFSET{i*#stride,#column}];"},
{"matrix_diag", "#scalartype #namereg = #pointer[#diag_offset<0?$OFFSET{(i - #diag_offset)*#stride, i}:$OFFSET{i*#stride, (i + #diag_offset)}];"}});
//Update accumulators
std::vector<std::string> str(simd_width);

View File

@@ -88,8 +88,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
process(stream, PARENT_NODE_TYPE,
{{"array0", "#scalartype #namereg = #pointer[#start];"},
{"array1", "#pointer += #start;"},
{"array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "}}, expressions, mappings);
{"array2", "#pointer += #start;"}}, expressions, mappings);
unsigned int local_size_0_ld = p_.local_size_0;
std::string local_size_0_ld_str = to_string(local_size_0_ld);
@@ -128,12 +127,12 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
std::map<std::string, std::string> accessors;
if(dot_type_==REDUCE_COLUMNS)
{
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride1", "#pointer + r*#ld", backend)+";";
accessors["array2"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "c*#stride", "#pointer + r*#ld", backend)+";";
accessors["repeat"] = data_type + " #namereg = " + vload(simd_width, "#scalartype", "(c%#tuplearg0)*#stride", "#pointer + (r%#tuplearg1)*#stride ", backend)+";";
}
else
{
accessors["array2"] = "#scalartype #namereg = #pointer[r*#stride1 + c*#ld];";
accessors["array2"] = "#scalartype #namereg = #pointer[r*#stride + c*#ld];";
accessors["repeat"] = "#scalartype #namereg = $VALUE{(r%#tuplearg0)*#stride, (c%#tuplearg1)*#stride};";
}
e->process_recursive(stream, PARENT_NODE_TYPE, accessors);
@@ -234,8 +233,7 @@ std::string gemv::generate_impl(std::string const & suffix, expressions_tuple co
process(stream, PARENT_NODE_TYPE,
{{"array0", "#scalartype #namereg = #pointer[#start];"},
{"array1", "#pointer += #start;"},
{"array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "}}, expressions, mappings);
{"array2", "#pointer += #start; "}}, expressions, mappings);
for (const auto & e : dots)
stream << e->process(Local(backend).get() + " #scalartype #name_buf[" + to_string(p_.local_size_1*local_size_0_ld) + "];") << std::endl;

View File

@@ -51,7 +51,7 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
process(stream, PARENT_NODE_TYPE, { {"array0", "#scalartype #namereg = #pointer[#start];"},
{"array1", "#pointer += #start;"},
{"array2", "#pointer = &$VALUE{#start1, #start2};"}}
{"array2", "#pointer += #start;"}}
, expressions, mappings);
fetching_loop_info(p_.fetching_policy, "M", stream, init0, upper_bound0, inc0, GlobalIdx0(backend).get(), GlobalSize0(backend).get(), device);
@@ -63,9 +63,9 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
stream << "{" << std::endl;
stream.inc_tab();
process(stream, PARENT_NODE_TYPE, { {"array2", data_type + " #namereg = $VALUE{i*#stride1,j*#stride2};"},
{"vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride1, j*#stride1)};"},
{"repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride1, (j%#tuplearg1)*#stride2};"},
process(stream, PARENT_NODE_TYPE, { {"array2", data_type + " #namereg = $VALUE{i*#stride,j};"},
{"vdiag", "#scalartype #namereg = ((i + ((#diag_offset<0)?#diag_offset:0))!=(j-((#diag_offset>0)?#diag_offset:0)))?0:$VALUE{min(i*#stride, j*#stride)};"},
{"repeat", "#scalartype #namereg = $VALUE{(i%#tuplearg0)*#stride, (j%#tuplearg1)};"},
{"outer", "#scalartype #namereg = ($LVALUE{i*#stride})*($RVALUE{j*#stride});"} }
, expressions, mappings);
@@ -78,7 +78,7 @@ std::string ger::generate_impl(std::string const & suffix, expressions_tuple con
{"host_scalar", p_.simd_width==1?"#name": InitPrefix(backend, data_type).get() + "(#name)"}}
, expressions, mappings);
process(stream, LHS_NODE_TYPE, { {"array2", "$VALUE{i*#stride1,j*#stride2} = #namereg;"} } , expressions, mappings);
process(stream, LHS_NODE_TYPE, { {"array2", "$VALUE{i*#stride,j} = #namereg;"} } , expressions, mappings);
stream.dec_tab();
stream << "}" << std::endl;

View File

@@ -22,7 +22,7 @@ inline std::string generate_arguments(std::string const & data_type, driver::Dev
process(stream, PARENT_NODE_TYPE, { {"array0", kwglobal + " #scalartype* #pointer, " + _size_t + " #start,"},
{"host_scalar", "#scalartype #name,"},
{"array1", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #start, " + _size_t + " #stride,"},
{"array2", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #ld, " + _size_t + " #start1, " + _size_t + " #start2, " + _size_t + " #stride1, " + _size_t + " #stride2,"},
{"array2", kwglobal + " " + data_type + "* #pointer, " + _size_t + " #ld, " + _size_t + " #start, " + _size_t + " #stride, "},
{"tuple4", "#scalartype #name0, #scalartype #name1, #scalartype #name2, #scalartype #name3,"}}
, expressions, mappings);
@@ -81,11 +81,9 @@ public:
}
else
{
kernel_.setSizeArg(current_arg_++, a->ld());
kernel_.setSizeArg(current_arg_++, a->start()[0]);
kernel_.setSizeArg(current_arg_++, a->start()[1]);
kernel_.setSizeArg(current_arg_++, a->ld()*a->stride()[1]);
kernel_.setSizeArg(current_arg_++, a->start()[0] + a->start()[1]*a->ld());
kernel_.setSizeArg(current_arg_++, a->stride()[0]);
kernel_.setSizeArg(current_arg_++, a->stride()[1]);
}
}
}