mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-01 10:27:43 +00:00
Feature to architecture and training params func
This commit is contained in:
@@ -293,3 +293,108 @@ class CNNClassifierTraining(gym.Env):
|
|||||||
self.epoch_idx = 0
|
self.epoch_idx = 0
|
||||||
|
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
def train_blueprint(self, lr, decay, momentum, batch_size, l1, l2, convs, fcs):
|
||||||
|
|
||||||
|
X, Y, Xv, Yv = self.data
|
||||||
|
nb_classes = self.nb_classes
|
||||||
|
|
||||||
|
reg = WeightRegularizer()
|
||||||
|
|
||||||
|
# a hack to make regularization variable
|
||||||
|
reg.l1 = K.variable(0.0)
|
||||||
|
reg.l2 = K.variable(0.0)
|
||||||
|
|
||||||
|
# input square image dimensions
|
||||||
|
img_rows, img_cols = X.shape[-1], X.shape[-1]
|
||||||
|
img_channels = X.shape[1]
|
||||||
|
|
||||||
|
# convert class vectors to binary class matrices
|
||||||
|
Y = np_utils.to_categorical(Y, nb_classes)
|
||||||
|
Yv = np_utils.to_categorical(Yv, nb_classes)
|
||||||
|
|
||||||
|
# here definition of the model happens
|
||||||
|
model = Sequential()
|
||||||
|
|
||||||
|
has_convs = False
|
||||||
|
# create all convolutional layers
|
||||||
|
for val, use in convs.reshape((5, 2)):
|
||||||
|
|
||||||
|
# Size of convolutional layer
|
||||||
|
cnvSz = int(val * 128)+1
|
||||||
|
|
||||||
|
if use < 0.5:
|
||||||
|
continue
|
||||||
|
has_convs = True
|
||||||
|
model.add(Convolution2D(cnvSz, 3, 3, border_mode='same',
|
||||||
|
input_shape=(img_channels, img_rows, img_cols),
|
||||||
|
W_regularizer=reg,
|
||||||
|
b_regularizer=reg))
|
||||||
|
model.add(Activation('relu'))
|
||||||
|
|
||||||
|
model.add(MaxPooling2D(pool_size=(2, 2)))
|
||||||
|
# model.add(Dropout(0.25))
|
||||||
|
|
||||||
|
if has_convs:
|
||||||
|
model.add(Flatten())
|
||||||
|
else:
|
||||||
|
model.add(Flatten( input_shape=(img_channels, img_rows, img_cols) )) # avoid excetpions on no convs
|
||||||
|
|
||||||
|
# create all fully connected layers
|
||||||
|
for val, use in fcs.reshape((2, 2)):
|
||||||
|
|
||||||
|
if use < 0.5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# choose fully connected layer size
|
||||||
|
densesz = int(1024 * val)+1
|
||||||
|
|
||||||
|
model.add(Dense(densesz,
|
||||||
|
W_regularizer=reg,
|
||||||
|
b_regularizer=reg))
|
||||||
|
model.add(Activation('relu'))
|
||||||
|
# model.add(Dropout(0.5))
|
||||||
|
|
||||||
|
model.add(Dense(nb_classes,
|
||||||
|
W_regularizer=reg,
|
||||||
|
b_regularizer=reg))
|
||||||
|
model.add(Activation('softmax'))
|
||||||
|
|
||||||
|
# let's train the model using SGD + momentum (how original).
|
||||||
|
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
|
||||||
|
model.compile(loss='categorical_crossentropy',
|
||||||
|
optimizer=sgd,
|
||||||
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
X = X.astype('float32')
|
||||||
|
Xv = Xv.astype('float32')
|
||||||
|
X /= 255
|
||||||
|
Xv /= 255
|
||||||
|
|
||||||
|
model = model
|
||||||
|
sgd = sgd
|
||||||
|
reg = reg
|
||||||
|
|
||||||
|
# set parameters of training step
|
||||||
|
|
||||||
|
sgd.lr.set_value(lr)
|
||||||
|
sgd.decay.set_value(decay)
|
||||||
|
sgd.momentum.set_value(momentum)
|
||||||
|
|
||||||
|
reg.l1.set_value(l1)
|
||||||
|
reg.l2.set_value(l2)
|
||||||
|
|
||||||
|
# train model for one epoch_idx
|
||||||
|
H = model.fit(X, Y,
|
||||||
|
batch_size=int(batch_size),
|
||||||
|
nb_epoch=10,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
|
diverged = math.isnan(H.history['loss'][-1])
|
||||||
|
acc = 0.0
|
||||||
|
|
||||||
|
if not diverged:
|
||||||
|
_, acc = model.evaluate(Xv, Yv)
|
||||||
|
|
||||||
|
return diverged, acc
|
Reference in New Issue
Block a user