2014-10-13 03:38:19 +02:00
|
|
|
from sklearn import tree
|
|
|
|
from sklearn import ensemble
|
2014-10-27 03:28:46 -04:00
|
|
|
import numpy as np
|
2014-10-13 03:38:19 +02:00
|
|
|
|
2014-10-16 17:49:17 -04:00
|
|
|
def gmean(a, axis=0, dtype=None):
    """Return the geometric mean of `a` along `axis`.

    The result is computed as ``exp(mean(log(a)))``.

    Parameters
    ----------
    a : array_like
        Input values. Anything that is not already an ``np.ndarray`` is
        converted with ``np.array(a, dtype=dtype)``.
    axis : int or None, optional
        Axis handed to ``mean`` on the log-transformed data (default 0).
    dtype : data-type, optional
        When given, the input is converted to this dtype before the
        logarithm is taken. Masked arrays are converted with
        ``np.ma.asarray`` so that they remain masked arrays.

    Returns
    -------
    The exponential of the mean of ``log(a)`` taken along `axis`.
    """
    if not isinstance(a, np.ndarray):
        # Not an ndarray: attempt to convert it.
        log_a = np.log(np.array(a, dtype=dtype))
    elif dtype is not None:
        # Compare against None explicitly: plain ``np.dtype`` instances
        # (e.g. ``np.dtype('float32')``) evaluate as falsy, so a bare
        # truthiness test would silently drop the requested dtype.
        if isinstance(a, np.ma.MaskedArray):
            log_a = np.log(np.ma.asarray(a, dtype=dtype))
        else:
            log_a = np.log(np.asarray(a, dtype=dtype))
    else:
        log_a = np.log(a)
    return np.exp(log_a.mean(axis=axis))
|
|
|
|
|
2014-10-04 08:58:11 +02:00
|
|
|
def train_model(X, Y, profiles, metric):
    """Fit a random-forest regressor on shuffled data and print speedup percentiles.

    The rows of `X` and `Y` are shuffled with one random permutation, `Y` is
    divided by its global maximum, and the first 75% of the rows are used
    to fit a ``RandomForestRegressor`` (10 estimators, ``max_depth=3``).
    On the remaining rows, two ratios are computed per row, each being
    column 0 of `Y` divided by a selected column of `Y`:

    * "Testing": the column selected by ``argmin`` of the model prediction;
    * "Optimal": the column selected by ``argmin`` of `Y` itself.

    The 5th/25th/50th/75th/95th percentiles of both are printed.

    Parameters
    ----------
    X, Y : 2-D arrays indexed as ``[rows, :]``; they are sliced with the
        same row indices, so they are expected to have matching row counts.
    profiles, metric : not used by this function.

    Returns
    -------
    The fitted ``RandomForestRegressor``.
    """
    # Shuffle
    n_rows = X.shape[0]
    order = np.random.permutation(n_rows)
    X, Y = X[order, :], Y[order, :]

    # Normalize
    Y = Y / np.max(Y)

    # Train the model
    cut = int(0.75 * n_rows)
    X_train, Y_train = X[:cut, :], Y[:cut, :]
    X_test, Y_test = X[cut:, :], Y[cut:, :]
    clf = ensemble.RandomForestRegressor(10, max_depth=3).fit(X_train, Y_train)

    def ratio_to_first(columns):
        # Per held-out row: column 0 divided by the chosen column.
        row_ids = np.arange(Y_test.shape[0])
        return Y_test[:, 0] / Y_test[row_ids, columns]

    predicted = ratio_to_first(np.argmin(clf.predict(X_test), axis=1))
    optimal = ratio_to_first(np.argmin(Y_test, axis=1))

    def percentiles(values):
        return tuple(np.percentile(values, q) for q in (5, 25, 50, 75, 95))

    print("Percentile :\t 5 \t 25 \t 50 \t 75 \t 95")
    print("Testing speedup:\t %.2f\t %.2f\t %.2f\t %.2f\t %.3f" % percentiles(predicted))
    print("Optimal speedup:\t %.2f\t %.2f\t %.2f\t %.2f\t %.3f" % percentiles(optimal))

    return clf
|