diff --git a/baselines/acktr/filters.py b/baselines/common/filters.py similarity index 100% rename from baselines/acktr/filters.py rename to baselines/common/filters.py diff --git a/baselines/common/runners.py b/baselines/common/runners.py new file mode 100644 index 0000000..33b4365 --- /dev/null +++ b/baselines/common/runners.py @@ -0,0 +1,18 @@ +import numpy as np +from abc import ABC, abstractmethod + +class AbstractEnvRunner(ABC): + def __init__(self, *, env, model, nsteps): + self.env = env + self.model = model + nenv = env.num_envs + self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape + self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=model.train_model.X.dtype.name) + self.obs[:] = env.reset() + self.nsteps = nsteps + self.states = model.initial_state + self.dones = [False for _ in range(nenv)] + + @abstractmethod + def run(self): + raise NotImplementedError diff --git a/baselines/acktr/running_stat.py b/baselines/common/running_stat.py similarity index 100% rename from baselines/acktr/running_stat.py rename to baselines/common/running_stat.py