GEMM now specified in terms of MNK rather than MKN

This commit is contained in:
Philippe Tillet
2014-10-29 23:43:11 -04:00
parent a8531cac37
commit de48ccc7b1

View File

@@ -42,8 +42,6 @@ def do_tuning(args, devices):
json_out = {}
json_out["version"] = "1.0"
print json_out
def map_to_list(T, x):
return list(map(T, x if isinstance(x, list) else [x]))
@@ -162,13 +160,13 @@ def do_tuning(args, devices):
def execution_handler(sizes, fname=os.devnull, parameters=None):
A_trans = layout[0]
B_trans = layout[1]
A = vcl.Matrix((sizes[0], sizes[1]) if A_trans=='N' else (sizes[1],sizes[0]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
B = vcl.Matrix((sizes[1], sizes[2]) if B_trans=='N' else (sizes[2],sizes[1]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
A = vcl.Matrix((sizes[0], sizes[2]) if A_trans=='N' else (sizes[2],sizes[0]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
B = vcl.Matrix((sizes[2], sizes[1]) if B_trans=='N' else (sizes[1],sizes[2]), context=ctx, dtype=datatype, layout=vcl.COL_MAJOR);
LHS = A if A_trans=='N' else A.T
RHS = B if B_trans=='N' else B.T
alpha = vcl.HostScalar(1.0, context=ctx, dtype = datatype)
beta = vcl.HostScalar(1.0, context=ctx, dtype = datatype)
C = vcl.Matrix((sizes[0], sizes[2]), context=ctx, dtype = datatype, layout=vcl.COL_MAJOR)
C = vcl.Matrix((sizes[0], sizes[1]), context=ctx, dtype = datatype, layout=vcl.COL_MAJOR)
return execute(device, vcl.Assign(C,LHS*RHS*alpha + C*beta),(A_trans, B_trans), sizes, fname, parameters)
tune(execution_handler, 100, 4000, 3,(layout[0], layout[1]))