Files
triton/tune/main.py

79 lines
3.4 KiB
Python
Raw Normal View History

import argparse, logging, sys
import isaac as sc
from tune.tune import Tuner
from tune.tools import metric_name_of
def parse_arguments():
    """Parse the command line and resolve the device and operations to tune.

    Enumerates every device of every platform returned by
    ``sc.driver.get_platforms()``, prints that list with the selected device
    marked ``[x]``, and maps each enabled ``--<operation>`` flag to the
    attribute of the same name in ``sc.templates``.

    Returns:
        A ``(device, operations, json)`` tuple: the selected device, the list
        of ``sc.templates`` entries whose flag was given (in table order), and
        the value of ``--json`` (``''`` when omitted).

    Exits (via ``parser.error``) when ``--device`` is not a valid index into
    the enumerated device list.
    """
    platforms = sc.driver.get_platforms()
    devices = [d for platform in platforms for d in platform.get_devices()]

    # (flag / template name, help text). One table drives both the argparse
    # flags and the template lookup so the two names can never drift apart;
    # a hand-maintained second list previously used 'gemm_tn'/'gemm_nt'/
    # 'gemm_tt', which are not attributes of the parsed namespace.
    tunable = [
        ('elementwise_1d', 'Tune AXPY'),
        ('reduce_1d', 'Tune DOT'),
        ('elementwise_2d', 'Tune GER'),
        ('reduce_2d_rows', 'Tune GEMV-N'),
        ('reduce_2d_cols', 'Tune GEMV-T'),
        ('matrix_product_nn', 'Tune GEMM-NN'),
        ('matrix_product_tn', 'Tune GEMM-TN'),
        ('matrix_product_nt', 'Tune GEMM-NT'),
        ('matrix_product_tt', 'Tune GEMM-TT'),
    ]

    # Command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--device", default=0, type=int, help='Device to tune for')
    parser.add_argument("-j", "--json", default='', type=str)
    for name, description in tunable:
        parser.add_argument('--' + name, action='store_true', help=description)
    args = parser.parse_args()

    # argparse already converted --device to int; reject indices that would
    # raise IndexError (or silently wrap around when negative).
    if not 0 <= args.device < len(devices):
        parser.error('invalid device index {0} ({1} device(s) available)'.format(
            args.device, len(devices)))
    device = devices[args.device]

    print("----------------")
    print("Devices available:")
    print("----------------")
    for d in devices:
        selected = '[' + ('x' if device == d else ' ') + ']'
        # Single-argument print() behaves the same on Python 2 and 3.
        print("{0} - {1} - {2} on {3}".format(
            selected, sc.driver.device_type_to_string(d.type), d.name, d.platform.name))

    operations = [getattr(sc.templates, name) for name, _ in tunable if getattr(args, name)]
    return (device, operations, args.json)
class ProgressBar(object):
    """Single-line console progress bar written to ``sys.stdout``.

    ``set_prefix`` draws an empty bar, ``update`` redraws the same line in
    place (each write starts with a carriage return), and ``set_finished``
    terminates the line.
    """

    def __init__(self, length, metric_name):
        """Create a bar ``length`` characters wide.

        ``metric_name`` is printed after the ``y`` value in ``update``.
        """
        self.length = length
        self.metric_name = metric_name
        # Initialised here so update() does not fail if called before set_prefix().
        self.prefix = ''

    def set_prefix(self, prefix):
        """Set the label shown before the bar and draw an empty (0%) bar."""
        self.prefix = prefix
        sys.stdout.write("{0}: [{1}] {2: >3}%".format(prefix.ljust(17), ' ' * self.length, 0))
        sys.stdout.flush()

    def set_finished(self):
        """End the current bar's line."""
        sys.stdout.write("\n")
        sys.stdout.flush()

    def update(self, i, total, x, y, complete=False):
        """Redraw the bar for step ``i`` out of ``total``.

        Args:
            i: Number of completed steps.
            total: Total number of steps. A non-positive total is drawn as a
                full bar instead of raising ZeroDivisionError.
            x: Iterable of numbers, printed as comma-separated integers.
            y: Number printed (truncated to int) before ``metric_name``.
            complete: Unused; kept so existing callers keep working.
        """
        fraction = float(i) / total if total > 0 else 1.0
        # Clamp so the bar keeps its fixed width even if i overshoots total.
        fraction = min(max(fraction, 0.0), 1.0)
        hashes = '#' * int(round(fraction * self.length))
        spaces = ' ' * (self.length - len(hashes))

        # Format of structures to print
        xformat = ','.join(str(int(v)) for v in x)
        yformat = int(y)
        percentformat = int(round(fraction * 100))

        sys.stdout.write("\r{0}: [{1}] {2: >3}% [{3} {4}] ({5})".format(
            self.prefix.ljust(17), hashes + spaces, percentformat, yformat,
            self.metric_name, xformat))
        sys.stdout.flush()
if __name__ == "__main__":
    # Send log records to stdout as bare messages, INFO level and above.
    logger = logging.getLogger(__name__)
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter('%(message)s'))
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    # Enable profiling on the driver's default queues (presumably so the tuner
    # can time kernels -- TODO confirm against Tuner).
    sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE

    # Named json_option (not `json`) to avoid shadowing the stdlib module name.
    device, operations, json_option = parse_arguments()
    for operation in operations:
        # A fresh progress bar per operation, labelled with that operation's metric.
        tuner = Tuner(logger, device, operation, json_option,
                      ProgressBar(30, metric_name_of(operation)))
        tuner.run(level='intermediate')