Auto-tuner: More flexibility
This commit is contained in:
@@ -47,12 +47,11 @@ def do_tuning(args, devices):
|
||||
'matrix-axpy' : args.blas2_size, 'row-wise-reduction' : args.blas2_size,
|
||||
'matrix-product': args.blas3_size}
|
||||
|
||||
for operation in ['matrix-product']:
|
||||
for operation in ['vector-axpy', 'reduction', 'matrix-axpy', 'row-wise-reduction', 'matrix-product']:
|
||||
|
||||
#Iterate through the datatypes
|
||||
for datatype in [vcl.float32, vcl.float64]:
|
||||
|
||||
if operation=='matrix-product' and datatype==vcl.float64 and args.no_dgemm:
|
||||
if any(x in args.exclude_operations.split(',') for x in [operation, operation + '-' + datatype.__name__]):
|
||||
continue
|
||||
|
||||
ctx = cl.Context([device])
|
||||
@@ -139,8 +138,7 @@ def do_tuning(args, devices):
|
||||
tune(execution_handler, log_space_gen_product(100, 4000, args.sample_size, 2), log_space_gen_product(100, 4000, 1000, 2), ())
|
||||
#Row-wise reduction
|
||||
if operation=='row-wise-reduction':
|
||||
layouts = ['N', 'T']
|
||||
for A_trans in layouts:
|
||||
for A_trans in args.gemv_layouts.split(','):
|
||||
def execution_handler(sizes, fname=os.devnull, parameters=None):
|
||||
A = vcl.Matrix(sizes if A_trans=='N' else sizes[::-1], context=ctx, dtype=datatype, layout=vcl.COL_MAJOR)
|
||||
x = vcl.Vector(sizes[1], context=ctx, dtype=datatype)
|
||||
@@ -150,8 +148,7 @@ def do_tuning(args, devices):
|
||||
tune(execution_handler, log_space_gen_product(100, 4000, args.sample_size, 2), log_space_gen_product(100, 4000, 1000, 2), (A_trans,))
|
||||
#Matrix Product
|
||||
if operation=='matrix-product':
|
||||
layouts = ['NN', 'NT', 'TN', 'TT']
|
||||
for layout in layouts:
|
||||
for layout in args.gemm_layouts.split(','):
|
||||
def execution_handler(sizes, fname=os.devnull, parameters=None):
|
||||
A_trans = layout[0]
|
||||
B_trans = layout[1]
|
||||
@@ -178,7 +175,9 @@ if __name__ == "__main__":
|
||||
print_devices_parser = subparsers.add_parser('list-devices', help='list the devices available')
|
||||
tune_parser = subparsers.add_parser('tune', help='tune using a specific configuration file')
|
||||
tune_parser.add_argument("--device", default=0, type=str)
|
||||
tune_parser.add_argument("--no-dgemm", default=True, type=bool)
|
||||
tune_parser.add_argument("--exclude-operations", default = '', type=str)
|
||||
tune_parser.add_argument("--gemm-layouts", default='NN,NT,TN,TT', type=str)
|
||||
tune_parser.add_argument("--gemv-layouts", default='N,T', type=str)
|
||||
tune_parser.add_argument("--viennacl-src-path", default='', type=str)
|
||||
|
||||
tune_subparsers = tune_parser.add_subparsers(dest='method')
|
||||
|
Reference in New Issue
Block a user