Code quality: more renaming
This commit is contained in:
28
tune/main.py
28
tune/main.py
@@ -10,15 +10,15 @@ def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--device", default=0, type=int, help='Device to tune for')
|
||||
parser.add_argument("-j", "--json", default='', type=str)
|
||||
parser.add_argument('--axpy', action='store_true', help='Tune AXPY')
|
||||
parser.add_argument('--dot', action='store_true', help='Tune DOT')
|
||||
parser.add_argument('--ger', action='store_true', help='Tune GER')
|
||||
parser.add_argument('--gemv_n', action='store_true', help='Tune GEMV-N')
|
||||
parser.add_argument('--gemv_t', action='store_true', help='Tune GEMV-T')
|
||||
parser.add_argument('--gemm_nn', action='store_true', help='Tune GEMM-NN')
|
||||
parser.add_argument('--gemm_tn', action='store_true', help='Tune GEMM-TN')
|
||||
parser.add_argument('--gemm_nt', action='store_true', help='Tune GEMM-NT')
|
||||
parser.add_argument('--gemm_tt', action='store_true', help='Tune GEMM-TT')
|
||||
parser.add_argument('--elementwise_1d', action='store_true', help='Tune AXPY')
|
||||
parser.add_argument('--reduce_1d', action='store_true', help='Tune DOT')
|
||||
parser.add_argument('--elementwise_2d', action='store_true', help='Tune GER')
|
||||
parser.add_argument('--reduce_2d_rows', action='store_true', help='Tune GEMV-N')
|
||||
parser.add_argument('--reduce_2d_cols', action='store_true', help='Tune GEMV-T')
|
||||
parser.add_argument('--matrix_product_nn', action='store_true', help='Tune GEMM-NN')
|
||||
parser.add_argument('--matrix_product_tn', action='store_true', help='Tune GEMM-TN')
|
||||
parser.add_argument('--matrix_product_nt', action='store_true', help='Tune GEMM-NT')
|
||||
parser.add_argument('--matrix_product_tt', action='store_true', help='Tune GEMM-TT')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -31,7 +31,7 @@ def parse_arguments():
|
||||
print selected , '-', sc.driver.device_type_to_string(d.type), '-', d.name, 'on', d.platform.name
|
||||
|
||||
|
||||
operations = ['axpy', 'dot', 'ger', 'gemv_n', 'gemv_t', 'gemm_nn', 'gemm_tn', 'gemm_nt', 'gemm_tt']
|
||||
operations = ['elementwise_1d', 'reduce_1d', 'elementwise_2d', 'reduce_2d_rows', 'reduce_2d_cols', 'matrix_product_nn', 'gemm_tn', 'gemm_nt', 'gemm_tt']
|
||||
operations = [getattr(sc.templates,op) for op in operations if getattr(args, op)]
|
||||
|
||||
return (device, operations, args.json)
|
||||
@@ -63,16 +63,16 @@ class ProgressBar:
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = logging.getLogger(__name__)
|
||||
logelementwise_2d = logging.getLogger(__name__)
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
sh.setFormatter(logging.Formatter('%(message)s'))
|
||||
sh.setLevel(logging.INFO)
|
||||
logger.addHandler(sh)
|
||||
logger.setLevel(logging.INFO)
|
||||
logelementwise_2d.addHandler(sh)
|
||||
logelementwise_2d.setLevel(logging.INFO)
|
||||
|
||||
sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE
|
||||
device, operations, json = parse_arguments()
|
||||
|
||||
for operation in operations:
|
||||
tuner = Tuner(logger, device, operation, json, ProgressBar(30, metric_name_of(operation)))
|
||||
tuner = Tuner(logelementwise_2d, device, operation, json, ProgressBar(30, metric_name_of(operation)))
|
||||
tuner.run(level='intermediate')
|
||||
|
Reference in New Issue
Block a user