GEMM now specified in terms of MNK rather than MKN
This commit is contained in:
@@ -42,8 +42,6 @@ def do_tuning(args, devices):
|
||||
json_out = {}
|
||||
json_out["version"] = "1.0"
|
||||
|
||||
print json_out
|
||||
|
||||
def map_to_list(T, x):
|
||||
return list(map(T, x if isinstance(x, list) else [x]))
|
||||
|
||||
@@ -162,13 +160,13 @@ def do_tuning(args, devices):
|
||||
def execution_handler(sizes, fname=os.devnull, parameters=None):
|
||||
A_trans = layout[0]
|
||||
B_trans = layout[1]
|
||||
A = vcl.Matrix((sizes[0], sizes[1]) if A_trans=='N' else (sizes[1],sizes[0]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
|
||||
B = vcl.Matrix((sizes[1], sizes[2]) if B_trans=='N' else (sizes[2],sizes[1]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
|
||||
A = vcl.Matrix((sizes[0], sizes[2]) if A_trans=='N' else (sizes[2],sizes[0]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
|
||||
B = vcl.Matrix((sizes[2], sizes[1]) if B_trans=='N' else (sizes[1],sizes[2]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
|
||||
LHS = A if A_trans=='N' else A.T
|
||||
RHS = B if B_trans=='N' else B.T
|
||||
alpha = vcl.HostScalar(1.0, context=ctx, dtype = datatype)
|
||||
beta = vcl.HostScalar(1.0, context=ctx, dtype = datatype)
|
||||
C = vcl.Matrix((sizes[0], sizes[2]), context=ctx, dtype = datatype, layout=vcl.COL_MAJOR)
|
||||
C = vcl.Matrix((sizes[0], sizes[1]), context=ctx, dtype = datatype, layout=vcl.COL_MAJOR)
|
||||
return execute(device, vcl.Assign(C,LHS*RHS*alpha + C*beta),(A_trans, B_trans), sizes, fname, parameters)
|
||||
tune(execution_handler, 100, 4000, 3,(layout[0], layout[1]))
|
||||
|
||||
|
Reference in New Issue
Block a user