Tune: cleaning

This commit is contained in:
Philippe Tillet
2015-08-26 11:35:11 -04:00
parent 8e6a5eb415
commit 5d8d4f2c17
3 changed files with 64 additions and 62 deletions

View File

@@ -81,8 +81,18 @@ class IsaacApp(App):
sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE
self.logger.info('Using ' + device.name) self.logger.info('Using ' + device.name)
self.logger.info('') self.logger.info('')
tuner = Tuner(self.logger, device, sc.templates.axpy, '')
tid = thread.start_new_thread(Tuner.run, (tuner,)) def run():
operations = [('blas1', (sc.templates.axpy,)),
('blas2', (sc.templates.ger, sc.templates.gemv_n, sc.templates.gemv_t)),
('blas3', (sc.templates.gemm_nn, sc.templates.gemm_tn, sc.templates.gemm_nt, sc.templates.gemm_tt))]
for opclass, optype in operations:
for op in optype:
tuner = Tuner(self.logger, device, op, '')
tuner.run(self.config.get('autotuning', opclass).lower())
self.logger.info('')
tid = thread.start_new_thread(run, ())
else: else:
pass pass
button.text = 'Running...' if button.text == 'Run' else button.text button.text = 'Running...' if button.text == 'Run' else button.text

View File

@@ -14,8 +14,8 @@ from numpy import cumsum
import tools import tools
from tools import profile_execution_failure from tools import profile_execution_failure
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_LOCAL, fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL, sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL, sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL] sc.templates.fetching_policy_type.FETCH_FROM_LOCAL]

View File

@@ -31,15 +31,15 @@ class Tuner:
self.device = device self.device = device
self.operation = operation self.operation = operation
self.json_path = json_path self.json_path = json_path
def pprint_datapoint(self, x, y): def pprint_datapoint(self, x, y):
if self.logger: if self.logger:
self.logger.info(', '.join(map(str, x)) + ': ' + str(int(max(y))) + ' ' + tools.metric_name_of(self.operation)) self.logger.info(', '.join(map(str, x)) + ': ' + str(int(max(y))) + ' ' + tools.metric_name_of(self.operation))
def run(self, levels = {'BLAS1': 'intermediate', 'BLAS2':'intermediate', 'BLAS3':'intermediate'}): def run(self, level = 'intermediate'):
for key, level in levels.iteritems():
assert key in ['BLAS1', 'BLAS2', 'BLAS3'] assert level in ['simple', 'intermediate', 'full']
assert level in ['simple', 'intermediate', 'full']
device = self.device device = self.device
operation = self.operation operation = self.operation
@@ -48,63 +48,55 @@ class Tuner:
if self.logger: if self.logger:
self.logger.info('Now tuning ' + operation.__name__.replace('_','-').upper() + '...') self.logger.info('Now tuning ' + operation.__name__.replace('_','-').upper() + '...')
sizes = {}
#BLAS1 training sizes #BLAS1 training sizes
if levels['BLAS1']=='simple': if operation in [sc.templates.axpy, sc.templates.dot]:
blas1_sizes = [(1e7,)] if level=='simple':
elif levels['BLAS1']=='intermediate': sizes = [(1e7,)]
blas1_sizes = [(x,) for x in tools.expspace(1e3, 1e8, 10)] elif level=='intermediate':
else: sizes = [(x,) for x in tools.expspace(1e3, 1e8, 10)]
blas1_sizes = [(x,) for x in tools.expspace(1e3, 1e8, 30)] else:
sizes[sc.templates.axpy] = blas1_sizes sizes = [(x,) for x in tools.expspace(1e3, 1e8, 30)]
sizes[sc.templates.dot] = blas1_sizes
#BLAS2 training sizes #BLAS2 training sizes
if levels['BLAS2']=='simple': if operation in [sc.templates.ger, sc.templates.gemv_n, sc.templates.gemv_t]:
blas2_sizes = [(1536, 1536)] if level=='simple':
elif levels['BLAS2']=='intermediate': sizes = [(1536, 1536)]
blas2_sizes = [(1000,256), elif level=='intermediate':
(4096,256), sizes = [(1000,256),
(256, 1000), (4096,256),
(256, 4096), (256, 1000),
(169,256), (256, 4096),
(169, 384), (169,256),
(729,256), (169, 384),
(3025,96)] (729,256),
else: (3025,96)]
blas2_sizes = product(pow2range(4,17), pow2range(4,17)) else:
sizes[sc.templates.ger] = blas2_sizes sizes = product(pow2range(4,17), pow2range(4,17))
sizes[sc.templates.gemv_n] = blas2_sizes
sizes[sc.templates.gemv_t] = blas2_sizes
#BLAS3 training sizes #BLAS3 training sizes
if levels['BLAS3']=='simple': if operation in [sc.templates.gemm_nn, sc.templates.gemm_nt, sc.templates.gemm_tn, sc.templates.gemm_tt]:
blas3_sizes = [(1536,1536,1536)] if level=='simple':
elif levels['BLAS3']=='intermediate': sizes = [(1536,1536,1536)]
blas3_sizes = [(32, 32, 16000), elif level=='intermediate':
(3025,96,363), sizes = [(32, 32, 16000),
(729,128,1200), (3025,96,363),
(169,384,2304), (729,128,1200),
(169,192,1728), (169,384,2304),
(169,128,1728), (169,192,1728),
(169,1728,128), (169,128,1728),
(169,1728,192), (169,1728,128),
(169,2304,384), (169,1728,192),
(729,1200,128), (169,2304,384),
(1728,128,169), (729,1200,128),
(1728,192,169), (1728,128,169),
(2304,384,169), (1728,192,169),
(1200,128,729), (2304,384,169),
(363,96,3025)] (1200,128,729),
elif levels['BLAS3']=='full': (363,96,3025)]
blas3_sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 15)) elif level=='full':
sizes[sc.templates.gemm_nn] = blas3_sizes sizes = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 15))
sizes[sc.templates.gemm_tn] = blas3_sizes
sizes[sc.templates.gemm_nt] = blas3_sizes
sizes[sc.templates.gemm_tt] = blas3_sizes
#Remove duplicates #Remove duplicates and or too small/big tuples
sizes = unique(list(sizes[operation]))
sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1] sizes = [x for x in sizes if 1e-4 <= tools.memory_footprint(operation, x) <= 1e-1]
#Training data #Training data