Auto-tuner: More flexibility

This commit is contained in:
Philippe Tillet
2014-10-29 12:45:20 -04:00
parent 086e51d291
commit b46f26e54c

View File

@@ -47,12 +47,11 @@ def do_tuning(args, devices):
'matrix-axpy' : args.blas2_size, 'row-wise-reduction' : args.blas2_size,
'matrix-product': args.blas3_size}
for operation in ['matrix-product']:
for operation in ['vector-axpy', 'reduction', 'matrix-axpy', 'row-wise-reduction', 'matrix-product']:
#Iterate through the datatypes
for datatype in [vcl.float32, vcl.float64]:
if operation=='matrix-product' and datatype==vcl.float64 and args.no_dgemm:
if any(x in args.exclude_operations.split(',') for x in [operation, operation + '-' + datatype.__name__]):
continue
ctx = cl.Context([device])
@@ -139,8 +138,7 @@ def do_tuning(args, devices):
tune(execution_handler, log_space_gen_product(100, 4000, args.sample_size, 2), log_space_gen_product(100, 4000, 1000, 2), ())
#Row-wise reduction
if operation=='row-wise-reduction':
layouts = ['N', 'T']
for A_trans in layouts:
for A_trans in args.gemv_layouts.split(','):
def execution_handler(sizes, fname=os.devnull, parameters=None):
A = vcl.Matrix(sizes if A_trans=='N' else sizes[::-1], context=ctx, dtype=datatype, layout=vcl.COL_MAJOR)
x = vcl.Vector(sizes[1], context=ctx, dtype=datatype)
@@ -150,8 +148,7 @@ def do_tuning(args, devices):
tune(execution_handler, log_space_gen_product(100, 4000, args.sample_size, 2), log_space_gen_product(100, 4000, 1000, 2), (A_trans,))
#Matrix Product
if operation=='matrix-product':
layouts = ['NN', 'NT', 'TN', 'TT']
for layout in layouts:
for layout in args.gemm_layouts.split(','):
def execution_handler(sizes, fname=os.devnull, parameters=None):
A_trans = layout[0]
B_trans = layout[1]
@@ -178,7 +175,9 @@ if __name__ == "__main__":
print_devices_parser = subparsers.add_parser('list-devices', help='list the devices available')
tune_parser = subparsers.add_parser('tune', help='tune using a specific configuration file')
tune_parser.add_argument("--device", default=0, type=str)
tune_parser.add_argument("--no-dgemm", default=True, type=bool)
tune_parser.add_argument("--exclude-operations", default = '', type=str)
tune_parser.add_argument("--gemm-layouts", default='NN,NT,TN,TT', type=str)
tune_parser.add_argument("--gemv-layouts", default='N,T', type=str)
tune_parser.add_argument("--viennacl-src-path", default='', type=str)
tune_subparsers = tune_parser.add_subparsers(dest='method')