Tune: cleaning
This commit is contained in:
@@ -81,8 +81,18 @@ class IsaacApp(App):
|
||||
sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE
|
||||
self.logger.info('Using ' + device.name)
|
||||
self.logger.info('')
|
||||
tuner = Tuner(self.logger, device, sc.templates.axpy, '')
|
||||
tid = thread.start_new_thread(Tuner.run, (tuner,))
|
||||
|
||||
def run():
|
||||
operations = [('blas1', (sc.templates.axpy,)),
|
||||
('blas2', (sc.templates.ger, sc.templates.gemv_n, sc.templates.gemv_t)),
|
||||
('blas3', (sc.templates.gemm_nn, sc.templates.gemm_tn, sc.templates.gemm_nt, sc.templates.gemm_tt))]
|
||||
for opclass, optype in operations:
|
||||
for op in optype:
|
||||
tuner = Tuner(self.logger, device, op, '')
|
||||
tuner.run(self.config.get('autotuning', opclass).lower())
|
||||
self.logger.info('')
|
||||
|
||||
tid = thread.start_new_thread(run, ())
|
||||
else:
|
||||
pass
|
||||
button.text = 'Running...' if button.text == 'Run' else button.text
|
||||
|
@@ -14,8 +14,8 @@ from numpy import cumsum
|
||||
import tools
|
||||
from tools import profile_execution_failure
|
||||
|
||||
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
|
||||
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
|
||||
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
|
||||
sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
|
||||
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
|
||||
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL]
|
||||
|
||||
|
@@ -32,13 +32,13 @@ class Tuner:
|
||||
self.operation = operation
|
||||
self.json_path = json_path
|
||||
|
||||
|
||||
def pprint_datapoint(self, x, y):
|
||||
if self.logger:
|
||||
self.logger.info(', '.join(map(str, x)) + ': ' + str(int(max(y))) + ' ' + tools.metric_name_of(self.operation))
|
||||
|
||||
def run(self, levels = {'BLAS1': 'intermediate', 'BLAS2':'intermediate', 'BLAS3':'intermediate'}):
|
||||
for key, level in levels.iteritems():
|
||||
assert key in ['BLAS1', 'BLAS2', 'BLAS3']
|
||||
def run(self, level = 'intermediate'):
|
||||
|
||||
assert level in ['simple', 'intermediate', 'full']
|
||||
|
||||
device = self.device
|
||||
@@ -48,22 +48,21 @@ class Tuner:
|
||||
if self.logger:
|
||||
self.logger.info('Now tuning ' + operation.__name__.replace('_','-').upper() + '...')
|
||||
|
||||
sizes = {}
|
||||
#BLAS1 training sizes
|
||||
if levels['BLAS1']=='simple':
|
||||
blas1_sizes = [(1e7,)]
|
||||
elif levels['BLAS1']=='intermediate':
|
||||
blas1_sizes = [(x,) for x in tools.expspace(1e3, 1e8, 10)]
|
||||
if operation in [sc.templates.axpy, sc.templates.dot]:
|
||||
if level=='simple':
|
||||
sizes = [(1e7,)]
|
||||
elif level=='intermediate':
|
||||
sizes = [(x,) for x in tools.expspace(1e3, 1e8, 10)]
|
||||
else:
|
||||
blas1_sizes = [(x,) for x in tools.expspace(1e3, 1e8, 30)]
|
||||
sizes[sc.templates.axpy] = blas1_sizes
|
||||
sizes[sc.templates.dot] = blas1_sizes
|
||||
sizes = [(x,) for x in tools.expspace(1e3, 1e8, 30)]
|
||||
|
||||
#BLAS2 training sizes
|
||||
if levels['BLAS2']=='simple':
|
||||
blas2_sizes = [(1536, 1536)]
|
||||
elif levels['BLAS2']=='intermediate':
|
||||
blas2_sizes = [(1000,256),
|
||||
if operation in [sc.templates.ger, sc.templates.gemv_n, sc.templates.gemv_t]:
|
||||
if level=='simple':
|
||||
sizes = [(1536, 1536)]
|
||||
elif level=='intermediate':
|
||||
sizes = [(1000,256),
|
||||
(4096,256),
|
||||
(256, 1000),
|
||||
(256, 4096),
|
||||
@@ -72,16 +71,14 @@ class Tuner:
|
||||
(729,256),
|
||||
(3025,96)]
|
||||
else:
|
||||
blas2_sizes = product(pow2range(4,17), pow2range(4,17))
|
||||
sizes[sc.templates.ger] = blas2_sizes
|
||||
sizes[sc.templates.gemv_n] = blas2_sizes
|
||||
sizes[sc.templates.gemv_t] = blas2_sizes
|
||||
sizes = product(pow2range(4,17), pow2range(4,17))
|
||||
|
||||
#BLAS3 training sizes
|
||||
if levels['BLAS3']=='simple':
|
||||
blas3_sizes = [(1536,1536,1536)]
|
||||
elif levels['BLAS3']=='intermediate':
|
||||
blas3_sizes = [(32, 32, 16000),
|
||||
if operation in [sc.templates.gemm_nn, sc.templates.gemm_nt, sc.templates.gemm_tn, sc.templates.gemm_tt]:
|
||||
if level=='simple':
|
||||
sizes = [(1536,1536,1536)]
|
||||
elif level=='intermediate':
|
||||
sizes = [(32, 32, 16000),
|
||||
(3025,96,363),
|
||||
(729,128,1200),
|
||||
(169,384,2304),
|
||||
@@ -96,15 +93,10 @@ class Tuner:
|
||||
(2304,384,169),
|
||||
(1200,128,729),
|
||||
(363,96,3025)]
|
||||
elif levels['BLAS3']=='full':
|
||||
blas3_sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 15))
|
||||
sizes[sc.templates.gemm_nn] = blas3_sizes
|
||||
sizes[sc.templates.gemm_tn] = blas3_sizes
|
||||
sizes[sc.templates.gemm_nt] = blas3_sizes
|
||||
sizes[sc.templates.gemm_tt] = blas3_sizes
|
||||
elif level=='full':
|
||||
sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 15))
|
||||
|
||||
#Remove duplicates
|
||||
sizes = unique(list(sizes[operation]))
|
||||
#Remove duplicates and or too small/big tuples
|
||||
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
|
||||
|
||||
#Training data
|
||||
|
Reference in New Issue
Block a user