Tuning: added ger default sizes
This commit is contained in:
@@ -46,7 +46,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
// int_t M = MNK[0]; int_t N = MNK[1];
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
|
||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||
throw operation_not_supported_exception("Only local memory is supported for GEMM");
|
||||
|
@@ -132,7 +132,9 @@ void model::execute(controller<expressions_tuple> const & expr)
|
||||
}
|
||||
|
||||
model::templates_container const & model::templates() const
|
||||
{ return templates_; }
|
||||
{
|
||||
return templates_;
|
||||
}
|
||||
|
||||
///////////////////
|
||||
|
||||
|
@@ -7,6 +7,10 @@ from sklearn import ensemble
|
||||
import isaac as isc
|
||||
import optimize, tools, model
|
||||
|
||||
from json import encoder
|
||||
encoder.FLOAT_REPR = lambda o: format(o, '.2f')
|
||||
encoder.separators = (',',':')
|
||||
|
||||
def unique(L):
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
@@ -23,7 +27,7 @@ def tune(device, operation, json_path):
|
||||
|
||||
#List of size tuples to use
|
||||
sizes = {}
|
||||
sizes[isc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e7, 4)]
|
||||
sizes[isc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e8, 4)]
|
||||
sizes[isc.templates.gemv_n] = product(pow2range(4,17), pow2range(4,17))
|
||||
sizes[isc.templates.gemv_t] = sizes[isc.templates.gemv_n]
|
||||
sizes[isc.templates.gemm_nn] = product(pow2range(6, 12), pow2range(6, 12), pow2range(6, 12))
|
||||
@@ -31,6 +35,9 @@ def tune(device, operation, json_path):
|
||||
sizes[isc.templates.gemm_nt] = sizes[isc.templates.gemm_nn]
|
||||
sizes[isc.templates.gemm_tt] = sizes[isc.templates.gemm_nn]
|
||||
|
||||
#ger
|
||||
sizes[isc.templates.ger] = [(1536,1536)]
|
||||
|
||||
#AlexNet sizes
|
||||
sizes[isc.templates.gemm_nn] = [(3025,96,363),
|
||||
(729,128,1200),
|
||||
|
Reference in New Issue
Block a user