Runtime: More progress towards cuBLAS integration

Philippe Tillet
2016-10-04 01:02:43 -04:00
parent fb9669a34d
commit ffb9548b6a
18 changed files with 170 additions and 210 deletions

View File

@@ -84,6 +84,7 @@ public:
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const = 0;
virtual int is_invalid(expression_tree const & expressions, driver::Device const & device) const = 0;
virtual void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & expressions) = 0;
+virtual expression_type type() const = 0;
std::string generate(std::string const & suffix, expression_tree const & expressions, driver::Device const & device);
std::shared_ptr<base> getptr();
};
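Note on this hunk: the new pure-virtual type() lets every kernel template report its own expression_type, which is what allows the profiles cache and the Python bindings further down to drop the separately-passed expression type. A minimal sketch of the idea in Python, with hypothetical class and enum names rather than the library's own:

    # Sketch only: hypothetical stand-ins for templates::base and expression_type.
    class KernelTemplate(object):
        def type(self):
            raise NotImplementedError

    class Elementwise1D(KernelTemplate):
        def type(self):
            return 'ELEMENTWISE_1D'

    class GemmNT(KernelTemplate):
        def type(self):
            return 'GEMM_NT'

    # A profile cache can now be keyed by (template.type(), dtype) instead of
    # requiring callers to pass the expression type alongside the template.
    profiles = {}
    tpl = GemmNT()
    profiles[tpl.type(), 'float32'] = tpl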

View File

@@ -38,6 +38,7 @@ public:
elementwise_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
+expression_type type() const;
private:
unsigned int ng_;
fetch_type fetch_;

View File

@@ -39,6 +39,7 @@ public:
elementwise_2d(unsigned int vwidth, unsigned int ls0, unsigned int ls1, unsigned int ng0, unsigned int ng1, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
+expression_type type() const;
private:
unsigned int ng0_;
unsigned int ng1_;

View File

@@ -39,6 +39,7 @@ public:
int is_invalid(expression_tree const &, driver::Device const &) const;
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & h);
+expression_type type() const;
private:
const char A_trans_;
const char B_trans_;
@@ -62,6 +63,7 @@ public:
, int_t lf0, int_t lf1, char A_trans, char B_trans);
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & h);
+expression_type type() const;
private:
//Parameters

View File

@@ -43,6 +43,8 @@ public:
reduce_1d(unsigned int vwidth, unsigned int ls, unsigned int ng, fetch_type fetch);
std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
+expression_type type() const;
private:
unsigned int ng_;
fetch_type fetch_;

View File

@@ -44,6 +44,7 @@ private:
public:
virtual std::vector<int_t> input_sizes(expression_tree const & expressions) const;
void enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const &);
+expression_type type() const;
private:
unsigned int ng0_;
unsigned int ng1_;

View File

@@ -53,7 +53,7 @@ public:
public:
value_type(expression_type, numeric_type, predictors::random_forest const &, std::vector< std::shared_ptr<templates::base> > const &, driver::CommandQueue const &);
-value_type(expression_type, numeric_type, std::shared_ptr<templates::base> const &, driver::CommandQueue const &);
+value_type(numeric_type, std::shared_ptr<templates::base> const &, driver::CommandQueue const &);
void execute(runtime::execution_handler const &);
templates_container const & templates() const;

View File

@@ -43,6 +43,9 @@ int elementwise_1d::is_invalid_impl(driver::Device const &, expression_tree cons
return TEMPLATE_VALID;
}
+expression_type elementwise_1d::type() const
+{ return ELEMENTWISE_1D; }
std::string elementwise_1d::generate_impl(std::string const & suffix, expression_tree const & tree, driver::Device const & device, symbolic::symbols_table const & symbols) const
{
driver::backend_type backend = device.backend();

View File

@@ -42,6 +42,9 @@ int elementwise_2d::is_invalid_impl(driver::Device const &, expression_tree cons
return TEMPLATE_VALID;
}
+expression_type elementwise_2d::type() const
+{ return ELEMENTWISE_2D; }
std::string elementwise_2d::generate_impl(std::string const & suffix, expression_tree const & tree, driver::Device const & device, symbolic::symbols_table const & symbols) const
{
std::string init0, upper_bound0, inc0, init1, upper_bound1, inc1;

View File

@@ -50,16 +50,11 @@ std::vector<int_t> infos(expression_tree const & tree, symbolic::preset::gemm::a
}
/* ------------------ CUBLAS ------------------ */
-bool cublas_gemm::init()
-{
-return driver::dispatch::cublasinit();
-}
cublas_gemm::cublas_gemm(char A_trans, char B_trans): A_trans_(A_trans), B_trans_(B_trans), init_(driver::dispatch::cublasinit())
{ }
int cublas_gemm::is_invalid(expression_tree const &, driver::Device const & device) const
-{ return init_ && device.backend()==driver::CUDA; }
+{ return (init_ && device.backend()==driver::CUDA)?0:-1; }
std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions) const
{
@@ -67,9 +62,21 @@ std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions)
return infos((expression_tree&)expressions, dummy, A_trans_);
}
+expression_type cublas_gemm::type() const
+{
+if(A_trans_=='N' && B_trans_=='N')
+return GEMM_NN;
+else if(A_trans_=='T' && B_trans_=='N')
+return GEMM_TN;
+else if(A_trans_=='N' && B_trans_=='T')
+return GEMM_NT;
+else
+return GEMM_TT;
+}
void cublas_gemm::enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & control)
{
-namespace drv = driver;;
+namespace drv = driver;
//Get GEMM info
symbolic::preset::gemm::args args;
std::vector<int_t> MNK = infos(control.x(), args, A_trans_);
@@ -115,6 +122,19 @@ unsigned int gemm::lmem_usage(expression_tree const & expression) const
return N*size_of(expression.dtype());
}
+expression_type gemm::type() const
+{
+if(A_trans_=='N' && B_trans_=='N')
+return GEMM_NN;
+else if(A_trans_=='T' && B_trans_=='N')
+return GEMM_TN;
+else if(A_trans_=='N' && B_trans_=='T')
+return GEMM_NT;
+else
+return GEMM_TT;
+}
unsigned int gemm::registers_usage(expression_tree const & expression) const
{
unsigned int N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;
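Note on these hunks: gemm::type() and cublas_gemm::type() resolve the stored transposition flags to one of the four GEMM variants, and cublas_gemm::is_invalid now returns 0 for usable (cuBLAS loaded and a CUDA device) and -1 otherwise. A hedged Python sketch of the same mapping, with illustrative names rather than the C++ enum:

    # Illustrative mapping from ('N'/'T', 'N'/'T') to the GEMM variant, mirroring
    # gemm::type() and cublas_gemm::type(); anything else falls back to GEMM_TT
    # as in the C++ else-branch.
    def gemm_variant(a_trans, b_trans):
        table = {('N', 'N'): 'GEMM_NN', ('T', 'N'): 'GEMM_TN', ('N', 'T'): 'GEMM_NT'}
        return table.get((a_trans, b_trans), 'GEMM_TT')

    def cublas_usable(cublas_initialized, backend):
        # 0 means the cuBLAS template may be selected, -1 means it may not.
        return 0 if (cublas_initialized and backend == 'CUDA') else -1

    assert gemm_variant('T', 'N') == 'GEMM_TN'
    assert cublas_usable(True, 'CUDA') == 0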

View File

@@ -55,6 +55,9 @@ unsigned int reduce_1d::temporary_workspace(expression_tree const &) const
return 0;
}
+expression_type reduce_1d::type() const
+{ return REDUCE_1D; }
inline void reduce_1d::reduce_1d_local_memory(kernel_generation_stream & stream, unsigned int size, std::vector<symbolic::reduce_1d*> exprs,
std::string const & buf_str, std::string const & buf_value_str, driver::backend_type) const
{

View File

@@ -290,6 +290,14 @@ std::vector<int_t> reduce_2d::input_sizes(expression_tree const & tree) const
return {shape[0], shape[1]};
}
+expression_type reduce_2d::type() const
+{
+if(reduction_type_==REDUCE_ROWS)
+return REDUCE_2D_ROWS;
+else
+return REDUCE_2D_COLS;
+}
void reduce_2d::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & control)
{
expression_tree const & tree = control.x();

View File

@@ -77,7 +77,7 @@ profiles::value_type::value_type(expression_type etype, numeric_type dtype, pred
}
-profiles::value_type::value_type(expression_type etype, numeric_type dtype, std::shared_ptr<templates::base> const & tp, driver::CommandQueue const & queue) : templates_(1,tp), queue_(queue), cache_(driver::backend::programs::get(queue,etype,dtype))
+profiles::value_type::value_type(numeric_type dtype, std::shared_ptr<templates::base> const & tp, driver::CommandQueue const & queue) : templates_(1,tp), queue_(queue), cache_(driver::backend::programs::get(queue,tp->type(),dtype))
{
cache_.clear();
}
@@ -197,7 +197,7 @@ void profiles::import(std::string const & str, driver::CommandQueue const & queu
result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, predictor, templates, queue);
}
else
-result[{etype, dtype}] = std::make_shared<value_type>(etype, dtype, templates[0], queue);
+result[{etype, dtype}] = std::make_shared<value_type>(dtype, templates[0], queue);
}
}
}

View File

@@ -86,31 +86,5 @@ namespace tools
throw;
}
}
-inline sc::expression_type extract_template_type(bp::object const & odtype)
-{
-std::string name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
-if(name=="class")
-name = bp::extract<std::string>(odtype.attr("__name__"))();
-else
-name = bp::extract<std::string>(odtype.attr("__class__").attr("__name__"))();
-if(name=="elementwise_1d") return sc::ELEMENTWISE_1D;
-else if(name=="elementwise_2d") return sc::ELEMENTWISE_2D;
-else if(name=="reduce_1d") return sc::REDUCE_1D;
-else if(name=="reduce_2d_rows") return sc::REDUCE_2D_ROWS;
-else if(name=="reduce_2d_cols") return sc::REDUCE_2D_COLS;
-else if(name=="gemm_nn") return sc::GEMM_NN;
-else if(name=="gemm_tn") return sc::GEMM_TN;
-else if(name=="gemm_nt") return sc::GEMM_NT;
-else if(name=="gemm_tt") return sc::GEMM_TT;
-else
-{
-PyErr_SetString(PyExc_TypeError, "Template type not understood");
-bp::throw_error_already_set();
-throw;
-}
-}
}
#endif

View File

@@ -106,7 +106,7 @@ namespace detail
std::shared_ptr<rt::profiles::value_type> construct_model(bp::object const & tp, bp::object dtype, sc::driver::CommandQueue & queue)
{
tpt::base* raw = bp::extract<tpt::base*>(tp);
-return std::make_shared<rt::profiles::value_type>(tools::extract_template_type(tp), tools::extract_dtype(dtype), raw->getptr(), queue);
+return std::make_shared<rt::profiles::value_type>(tools::extract_dtype(dtype), raw->getptr(), queue);
}
std::shared_ptr<sc::array>
@@ -219,9 +219,9 @@ namespace detail
{
static rt::profiles::value_type& get_item(rt::profiles::map_type& container, bp::tuple i_)
{
-sc::expression_type expression = tools::extract_template_type(i_[0]);
+tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
-rt::profiles::map_type::iterator i = container.find(std::make_pair(expression, dtype));
+rt::profiles::map_type::iterator i = container.find(std::make_pair(tpt->type(), dtype));
if (i == container.end())
{
PyErr_SetString(PyExc_KeyError, "Invalid key");
@@ -232,9 +232,9 @@ namespace detail
static void set_item(rt::profiles::map_type& container, bp::tuple i_, rt::profiles::value_type const & v)
{
-sc::expression_type expression = tools::extract_template_type(i_[0]);
+tpt::base* tpt = bp::extract<tpt::base*>(i_[0]);
sc::numeric_type dtype = tools::extract_dtype(i_[1]);
-container[std::make_pair(expression, dtype)].reset(new rt::profiles::value_type(v));
+container[std::make_pair(tpt->type(), dtype)].reset(new rt::profiles::value_type(v));
}
};
}
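Note on this hunk: the profiles map exposed to Python is now indexed by a template instance plus a dtype; get_item/set_item pull the tpt::base pointer out of the key and call type() on it, which is what the reworked tools.benchmark further down relies on. A hedged usage sketch, assuming the 'sc' module built from this commit and a queue obtained elsewhere (e.g. tree.context.queues[0]):

    # Hedged sketch: register a profile under a template *instance*; the
    # binding derives the expression type from tpl.type() internally.
    def register_cublas_nt_profile(sc, queue):
        tpl = sc.templates.cublas_gemm('N', 'T')
        queue.profiles[tpl, sc.float32] = sc.profile(tpl, sc.float32, queue)
        return tpl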

View File

@@ -39,28 +39,6 @@ fetch_types = [sc.templates.fetch_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
sc.templates.fetch_type.FETCH_FROM_LOCAL,
sc.templates.fetch_type.FETCH_FROM_LOCAL]
-def exhaustive(template, sizes, context):
-tree, _ = tools.tree_of(template, sizes, context)
-metric = tools.metric_of(template)
-nbits = tools.genetic_infos_of(template)['nbits']
-categorical = tools.genetic_infos_of(template)['categorical']
-ranges = [range(2**x) for x in nbits]
-ranges = list(product(*ranges))
-timings = {}
-best = None
-for idx, r in enumerate(ranges):
-parameters = tuple([fetch_types[x] if i in categorical else 2**x for i,x in enumerate(r)])
-try:
-time = tools.benchmark(template, parameters, tree)
-if not best or time < best[1]:
-best = parameters, time
-except profile_execution_failure:
-pass
-if best:
-stdout.write('%.2f %% | Best %.2f [ for %s ]\r'%(float(idx*100)/len(ranges),metric(sizes, best[1]), best[0]))
-return best[0]
class GeneticOptimizer:
def __init__(self, logger, naccept=500, niter=1000, cxpb=.4, mutpb=.4, popsize=10, progress_bar = None):
@@ -105,7 +83,10 @@ class GeneticOptimizer:
def evaluate(genome):
idx = tuple(genome)
if idx not in cache:
-cache[idx] = tools.benchmark(template, decode(genome), tree)
+time = tools.benchmark(template, template(*decode(genome)), tree)
+if time == float('inf'):
+return time,
+cache[idx] = time
self.progress_bar.update(max(len(cache), it), self.niter, decode(min(cache, key=cache.get)), metric(sizes, min(cache.values())))
return cache[idx],
@@ -132,11 +113,9 @@ class GeneticOptimizer:
genome = encode(prior if prior else list(initializer.next()))
while len(population) < self.popsize:
individual = creator.Individual(genome)
-try:
individual.fitness.values = toolbox.evaluate(genome)
+if max(individual.fitness.values) != float('inf'):
population += [individual]
-except profile_execution_failure:
-pass
genome = encode(list(initializer.next()))
hof.update(population)
@@ -146,26 +125,25 @@ class GeneticOptimizer:
#Generate offspring
offspring = []
while len(offspring) < self.popsize:
-try:
op_choice = random.random()
#Cross-over
if op_choice < self.cxpb:
ind1, ind2 = map(toolbox.clone, random.sample(population, 2))
ind1, ind2 = toolbox.mate(ind1, ind2)
ind = ind1
toolbox.evaluate(ind)
+if max(ind.fitness.values) != float('inf'):
offspring += [ind]
#Mutation
elif op_choice < self.cxpb + self.mutpb:
ind = toolbox.clone(random.choice(population))
ind, = toolbox.mutate(ind, 1.0/offsets[-1])
toolbox.evaluate(ind)
+if max(ind.fitness.values) != float('inf'):
offspring += [ind]
#Reproduction
else:
offspring += [random.choice(population)]
-except profile_execution_failure:
-pass
#Update fitnesses
fitnesses = toolbox.map(toolbox.evaluate, offspring)
@@ -195,9 +173,8 @@ def is_local_optimum(parameters, template, sizes, context):
sweep_over = [0,1,2,3,4]
#Evaluate the provided parameters guess
-try:
-reference = tools.benchmark(template, parameters, tree)
-except profile_execution_failure:
+reference = tools.benchmark(template, template(*parameters), tree)
+if isinf(reference):
return False
#Latency bound -- ignore
@@ -210,12 +187,9 @@ def is_local_optimum(parameters, template, sizes, context):
for x in product(*domain):
if x==parameters:
pass
-try:
-time = tools.benchmark(template, x, tree)
+time = tools.benchmark(template, template(*x), tree)
if time/reference < .98:
return False
-except profile_execution_failure:
-pass
return True
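Note on these hunks: the optimizer no longer catches profile_execution_failure; tools.benchmark signals failure by returning float('inf'), and candidates whose fitness is not finite are simply skipped. A generic, hedged sketch of that convention (illustrative, not the DEAP-based optimizer itself):

    # Failure-as-inf: evaluate() never raises, and only candidates with a
    # finite fitness make it into the population.
    import math
    import random

    def evaluate(candidate, run):
        try:
            return run(candidate)        # 'run' stands in for a kernel benchmark
        except RuntimeError:             # stands in for profile_execution_failure
            return float('inf')

    def build_population(sampler, run, size):
        population = []
        while len(population) < size:
            candidate = sampler()
            fitness = evaluate(candidate, run)
            if not math.isinf(fitness):  # drop configurations that failed to run
                population.append((candidate, fitness))
        return population

    # Toy usage: minimize x**2 over random samples.
    pop = build_population(lambda: random.uniform(-1, 1), lambda x: x * x, 5)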

View File

@@ -40,15 +40,18 @@ def linspace(a, b, n=100):
def expspace(a,b,N,r=128):
return [int(ceil(exp(x)/r)*r) for x in linspace(log(a), log(b), N)]
-def benchmark(template, setting, tree):
+def benchmark(operation, template, tree):
queue = tree.context.queues[0]
-queue.profiles[template, sc.float32] = sc.profile(template(*setting), sc.float32, queue)
+queue.profiles[template, sc.float32] = sc.profile(template, sc.float32, queue)
times = []
total = 0
i = 0
#Warm-up
+try:
z, events = sc.driver.enqueue(tree)
tree.context.queues[0].synchronize()
+except profile_execution_failure:
+return float("inf")
#Time
while total < 1e-1:
start = time()
@@ -119,6 +122,16 @@ def metric_name_of(template):
return 'GFLOPS'
return 'GB/S'
+def external_profiles(template):
+if template is sc.templates.gemm_nn:
+return [sc.templates.cublas_gemm('N', 'N')]
+elif template is sc.templates.gemm_tn:
+return [sc.templates.cublas_gemm('T', 'N')]
+elif template is sc.templates.gemm_nt:
+return [sc.templates.cublas_gemm('N', 'T')]
+elif template is sc.templates.gemm_tt:
+return [sc.templates.cublas_gemm('T', 'T')]
def genetic_infos_of(template):
if issubclass(template, sc.templates.elementwise_1d):
return {'categorical': [3], 'nbits': [3,4,4,2] }
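Note on these hunks: benchmark now takes the operation class plus an already-constructed template, registers it under queue.profiles[template, sc.float32], and returns float('inf') when the warm-up launch raises profile_execution_failure; external_profiles pairs each GEMM layout with the matching cublas_gemm instance so cuBLAS can be timed alongside the generated kernels. A hedged, standalone sketch of such a failure-tolerant timing helper (hypothetical 'run' callable, not the isaac runtime itself):

    # Warm up once, report inf on failure, otherwise accumulate ~0.1 s of
    # repetitions and keep the best time, in the spirit of tools.benchmark.
    from time import time

    def timed(run, min_total=1e-1):
        try:
            run()                        # warm-up; a failure means "not runnable here"
        except Exception:
            return float('inf')
        best, total = float('inf'), 0.0
        while total < min_total:
            start = time()
            run()
            elapsed = time() - start
            best = min(best, elapsed)
            total += elapsed
        return best

    # Layout -> cuBLAS transposition flags, mirroring external_profiles.
    CUBLAS_FLAGS = {'gemm_nn': ('N', 'N'), 'gemm_tn': ('T', 'N'),
                    'gemm_nt': ('N', 'T'), 'gemm_tt': ('T', 'T')}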

View File

@@ -69,73 +69,52 @@ class Tuner:
#BLAS1 training sizes
if operation in [sc.templates.elementwise_1d, sc.templates.reduce_1d]:
-if level=='simple':
-sizes = [(10000000,)]
-elif level=='intermediate':
-sizes = [(x,) for x in tools.expspace(1e3, 1e8, 10)]
-else:
-sizes = [(x,) for x in tools.expspace(1e3, 1e8, 100)]
+sizes = [(x,) for x in tools.expspace(1e3, 1e8, 20)]
#BLAS2 training sizes
if operation in [sc.templates.elementwise_2d, sc.templates.reduce_2d_rows, sc.templates.reduce_2d_cols]:
-if level=='simple':
-sizes = [(1536, 1536)]
-elif level=='intermediate':
-sizes = []
-#Square
-for N in [896, 1760, 2048, 2560]:
-sizes += [(N, N)]
-#Tall and Skinny
-for M in [16, 32, 64, 128]:
-for N in [1024, 4096, 16384, 65536, 262144]:
-sizes += [(M, N)]
-sizes += [(N, M)]
-else:
-sizes = product(pow2range(4,17), pow2range(4,17))
+sizes = []
+#Square
+for N in [896, 1760, 2048, 2560]:
+sizes += [(N, N)]
+#Tall and Skinny
+for M in [16, 32, 64, 128]:
+for N in [1024, 4096, 16384, 65536, 262144]:
+sizes += [(M, N)]
+sizes += [(N, M)]
#BLAS3 training sizes
if operation in [sc.templates.gemm_nn, sc.templates.gemm_nt, sc.templates.gemm_tn, sc.templates.gemm_tt]:
-if level=='simple':
-sizes = [(2560,2560,2560)]
-elif level=='intermediate':
-sizes = []
-#Square
-for N in [896, 1760, 2048, 2560]:
-sizes += [(N, N, N)]
-#LaPack
-for N in [896, 1760, 2048, 2560]:
-for K in [16, 32, 64, 128]:
-sizes += [(N, N, K)]
-#Covariance
-for N in [16, 32, 64, 128]:
-for K in [16000,32000,64000,128000]:
-sizes += [(N, N, K)]
-#DeepSpeech
-for M in [1760, 2048, 2560]:
-for N in [16, 32, 64, 128, M]:
-sizes += [(M, N, M)]
-elif level=='full':
-sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 17))
+sizes = []
+#Square
+for N in [896, 1760, 2048, 2560]:
+sizes += [(N, N, N)]
+#LaPack
+for N in [896, 1760, 2048, 2560]:
+for K in [16, 32, 64, 128]:
+sizes += [(N, N, K)]
+#Covariance
+for N in [16, 32, 64, 128]:
+for K in [16000,32000,64000,128000]:
+sizes += [(N, N, K)]
+#DeepSpeech
+for M in [1760, 2048, 2560]:
+for N in [16, 32, 64, 128, M]:
+sizes += [(M, N, M)]
-#Remove duplicates and or too small/big tuples
-sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 2e-1]
#Training data
performance = tools.metric_of(operation)
profiles, X, Y = [], [], []
-#Restore previous run
+#Restore progress
savepath = os.path.join('save', operation.__name__)
if not os.path.exists(savepath):
os.makedirs(savepath)
try:
with open(os.path.join(savepath, 'X.csv')) as f:
X = [tuple(map(int, row)) for row in csv.reader(f, delimiter=',')]
with open(os.path.join(savepath, 'Y.csv')) as f:
Y = [map(float, row) for row in csv.reader(f, delimiter=',')]
with open(os.path.join(savepath, 'profiles.csv')) as f:
def mmap(x):
if x=='FETCH_FROM_LOCAL':
@@ -149,94 +128,69 @@ class Tuner:
except:
pass
-##### Exploration #####
+#Tuning
for idx, x in enumerate(sizes):
+#Create new line on log
if idx>0:
self.progress_bar.set_finished()
self.progress_bar.set_prefix(', '.join(map(str, x)))
-#Skip if saved
+#Skip if already saved
if x in X:
row = Y[X.index(x)]
self.progress_bar.update(1, 1, profiles[argmax(row)], max(row))
continue
-#Check if the current best prediction is not a local optimum
-idx = len(X)
-nparams = len(profiles)
tree, operands = tools.tree_of(operation, x, context)
-if idx==0:
-retune = True
-predicted = None
-else:
-if nparams==1:
-predicted = profiles[0]
-else:
-clf = RandomForestRegressor(min(10, idx+1), max_depth=min(10, idx+1)).fit(X, Y)
-#clf, nrmse = model.train(X, Y, profiles)
-predperf = clf.predict(x)[0]
-best = (-predperf).argsort()
-perf = []
-for b in best:
-try:
-perf += [performance(x, tools.benchmark(operation, profiles[b], tree))]
-break
-except profile_execution_failure:
-pass
-if perf:
-predicted = profiles[best[argmax(perf)]]
-retune = not optimize.is_local_optimum(predicted, operation, x, context)
-else:
-retune = True
-predicted = None
+#Check if GA needs to run (i.e., current best prediction is not a local optimum)
+tune = True
+best = None
+if idx > 0:
+dim = min(10, idx+1)
+model = RandomForestRegressor(dim, dim).fit(X, Y)
+predictions = model.predict(x)[0]
+for idx in (-predictions).argsort():
+ts = tools.benchmark(operation, operation(*profiles[idx]), tree)
+if np.isfinite(ts):
+break
+if np.isfinite(ts):
+best = profiles[idx]
+tune = not optimize.is_local_optimum(predicted, operation, x, context)
#Retune if necessary
-if retune:
+if tune:
optimizer = optimize.GeneticOptimizer(self.logger, naccept=1000, niter=1000, cxpb=.4, mutpb=.4, popsize=20, progress_bar = self.progress_bar)
-new = optimizer.run(operation, x, context, prior=predicted)[0]
-if new not in profiles:
-profiles.append(new)
-if idx > 0:
-for xx,yy in zip(X, Y):
-_tree, _operands = tools.tree_of(operation, xx, context)
-try:
-time = tools.benchmark(operation, new, _tree)
-perf = performance(xx, time)
-except profile_execution_failure:
-perf = 0
-yy.append(0 if isinf(perf) else perf)
-##### Training #####
-y = []
-fastest = max(predperf) if nparams > 1 else None
-for ip, p in enumerate(profiles):
-try:
-perf = 0 if fastest and ip < nparams and predperf[ip]/fastest < .1 else performance(x,tools.benchmark(operation, p, tree))
-except profile_execution_failure:
-perf = 0
-y.append(0 if isinf(perf) else perf)
+best = optimizer.run(operation, x, context, prior=best)[0]
+if best not in profiles:
+profiles.append(best)
+for xx,yy in zip(X, Y):
+tree, _operands = tools.tree_of(operation, xx, context)
+time = tools.benchmark(operation, best, _tree)
+yy.append(performance(xx, time))
+#Update dataset
X.append(x)
+y = [performance(x,tools.benchmark(operation, prf, tree)) for prf in profiles]
Y.append(y)
#Save data
for (fname, data) in zip(['X.csv', 'Y.csv', 'profiles.csv'], [X, Y, profiles]):
with open(os.path.join(savepath, fname), 'wb') as f:
csv.writer(f).writerows(data)
#print performance info in case no tuning was done
-if not retune:
+if not tune:
row = Y[X.index(x)]
self.progress_bar.update(1, 1, profiles[argmax(row)], max(row))
self.progress_bar.set_finished()
-#Remove unused profiles
+#Adding external profiles
+#~ for prf in tools.external_profiles(operation):
+#~ x = [1024, 1024, 1024]
+#~ tree, operands = tools.tree_of(operation, x, context)
+#~ print performance(x,tools.benchmark(operation, prf, tree))
+#Pruning of useless profiles
if len(Y[0]) > 1:
unused = np.where(np.bincount(np.argmax(Y, 1))==0)[0]
profiles = [p for ip,p in enumerate(profiles) if ip not in unused]
Y = np.delete(Y, np.where(np.bincount(np.argmax(Y, 1))==0), axis=1).tolist()
-##### Exportation #####
+#Exporting to JSON
json_path = tools.sanitize(device.name) + '.json' if not self.json_path else self.json_path
if os.path.isfile(json_path):
json_data = json.load(open(json_path, 'r'))