2015-08-26 21:05:11 -04:00
|
|
|
import argparse, logging, sys
|
2015-08-16 19:58:54 -07:00
|
|
|
import isaac as sc
|
2015-08-26 21:05:11 -04:00
|
|
|
from tune.tune import Tuner
|
|
|
|
from tune.tools import metric_name_of
|
2015-08-16 19:58:54 -07:00
|
|
|
|
|
|
|
def parse_arguments():
    """Parse the command line and select the device and operations to tune.

    Prints every device reported by ``sc.driver`` and marks the selected one.
    An out-of-range ``--device`` index is reported as a usage error (argparse
    exits) instead of raising IndexError.

    Returns:
        A ``(device, operations, json)`` tuple:
        the selected device;
        the list of ``sc.templates`` attributes whose flag was given
        (possibly empty);
        and the raw ``--json`` value ('' when not given).
    """
    # Tunable operations, in flag order. Each name is both a command-line flag
    # (--<name>) and an attribute of sc.templates (looked up with getattr below).
    operation_names = ['axpy', 'gemv_n', 'gemv_t', 'gemm_nn', 'gemm_tn', 'gemm_nt', 'gemm_tt']

    platforms = sc.driver.get_platforms()
    devices = [d for platform in platforms for d in platform.get_devices()]

    #Command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--device", default=0, type=int, help='Device to tune for')
    # NOTE(review): forwarded unchanged to Tuner; presumably a JSON file path — confirm.
    parser.add_argument("-j", "--json", default='', type=str)
    for name in operation_names:
        # e.g. 'gemm_nn' -> flag --gemm_nn with help 'Tune GEMM-NN'
        parser.add_argument('--' + name, action='store_true',
                            help='Tune ' + name.upper().replace('_', '-'))
    args = parser.parse_args()

    # List the devices before validating the index, so that on a bad index the
    # user can still see which indices are valid.
    print("----------------")
    print("Devices available:")
    print("----------------")
    for i, d in enumerate(devices):
        mark = 'x' if i == args.device else ' '
        # A single pre-formatted string prints identically on Python 2 and 3.
        print('[{0}] - {1} - {2} on {3}'.format(
            mark, sc.driver.device_type_to_string(d.type), d.name, d.platform.name))

    # Negative indices would otherwise silently wrap to a device from the end
    # of the list; too-large ones would raise a bare IndexError.
    if not 0 <= args.device < len(devices):
        parser.error('invalid device index {0}: {1} device(s) available'.format(
            args.device, len(devices)))
    device = devices[args.device]

    operations = [getattr(sc.templates, name) for name in operation_names if getattr(args, name)]

    return (device, operations, args.json)
|
2015-08-16 19:58:54 -07:00
|
|
|
|
2015-08-26 21:05:11 -04:00
|
|
|
|
|
|
|
class ProgressBar(object):
    """Single-line console progress bar that is redrawn in place on stdout.

    Each call to ``update`` rewrites the current line (via a carriage return)
    with:
    - the label prefix;
    - a bar of ``length`` characters;
    - the completion percentage;
    - the integer-truncated metric value ``y``, followed by ``metric_name``;
    - the comma-joined, integer-truncated values of ``x``.
    """

    def __init__(self, length, metric_name):
        """
        Args:
            length: width of the bar, in characters.
            metric_name: label printed after the metric value.
        """
        self.length = length
        self.metric_name = metric_name
        # Default so that update() also works when set_prefix() was never called.
        self.prefix = ''
        # Width of the last line drawn; lets the next redraw erase leftovers.
        self._last_width = 0

    def set_prefix(self, prefix):
        """Set the label drawn (left-justified to 10 columns) before the bar."""
        self.prefix = prefix

    def update(self, i, total, x, y, complete=False):
        """Redraw the bar for step ``i`` out of ``total``.

        Args:
            i: number of steps completed.
            total: total number of steps. When it is 0, the bar is drawn full.
            x: iterable of numbers. Each is truncated to int and the values
                are joined with commas.
            y: metric value, truncated to int.
            complete: when true, terminate the line so that subsequent output
                starts on a fresh line.
        """
        # Clamp the fraction so i > total (or i < 0) cannot over/underflow the
        # bar, and avoid ZeroDivisionError when total is 0.
        if total:
            percent = min(max(float(i) / total, 0.0), 1.0)
        else:
            percent = 1.0
        hashes = '#' * int(round(percent * self.length))
        spaces = ' ' * (self.length - len(hashes))

        #Format of structures to print
        xformat = ','.join(str(int(v)) for v in x)
        yformat = int(y)
        percentformat = int(round(percent * 100))

        # prefix and metric_name are format *arguments*, not part of the
        # template, so literal '{' / '}' in them are printed verbatim.
        line = '{0}: [{1}] {2: >3}% [{3} {4}] ({5})'.format(
            self.prefix.ljust(10), hashes + spaces, percentformat,
            yformat, self.metric_name, xformat)
        # '\r' only returns the cursor; pad with spaces to blank out the tail
        # of a longer previous line (a negative count yields '').
        padding = ' ' * (self._last_width - len(line))
        sys.stdout.write('\r' + line + padding)
        if complete:
            sys.stdout.write('\n')
            self._last_width = 0
        else:
            self._last_width = len(line)
        sys.stdout.flush()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Log bare messages (no level/timestamp) to stdout, INFO and above.
    logger = logging.getLogger(__name__)
    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(logging.Formatter('%(message)s'))
    sh.setLevel(logging.INFO)
    logger.addHandler(sh)
    logger.setLevel(logging.INFO)

    # Default queues get profiling enabled (presumably so the tuner can time
    # kernels — confirm against Tuner).
    sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE

    # Named json_path rather than json to avoid shadowing the stdlib module.
    device, operations, json_path = parse_arguments()

    # Without any operation flag the loop below is a silent no-op; say so.
    if not operations:
        logger.warning('No operation selected; nothing to tune (see --help for the available operations).')

    for operation in operations:
        tuner = Tuner(logger, device, operation, json_path, ProgressBar(30, metric_name_of(operation)))
        tuner.run(level='full')
|