diff --git a/README.md b/README.md index adca3d7..c7d4483 100644 --- a/README.md +++ b/README.md @@ -48,15 +48,15 @@ The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2 git clone https://github.com/openai/baselines.git cd baselines ``` -- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, +- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, you may use ```bash - pip install tensorflow-gpu # if you have a CUDA-compatible gpu and proper drivers + pip install tensorflow-gpu==1.14 # if you have a CUDA-compatible gpu and proper drivers ``` or ```bash - pip install tensorflow + pip install tensorflow==1.14 ``` - should be sufficient. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/) + to install TensorFlow 1.14, which is the latest version of TensorFlow supported by the master branch. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/) for more details. 
- Install baselines package diff --git a/baselines/run.py b/baselines/run.py index e1b11a5..33bb15f 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -226,7 +226,7 @@ def main(args): state = model.initial_state if hasattr(model, 'initial_state') else None dones = np.zeros((1,)) - episode_rew = 0 + episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1) while True: if state is not None: actions, _, state, _ = model.step(obs,S=state, M=dones) @@ -234,13 +234,13 @@ def main(args): actions, _, _, _ = model.step(obs) obs, rew, done, _ = env.step(actions) - episode_rew += rew[0] if isinstance(env, VecEnv) else rew + episode_rew += rew env.render() - done = done.any() if isinstance(done, np.ndarray) else done - if done: - print('episode_rew={}'.format(episode_rew)) - episode_rew = 0 - obs = env.reset() + done_any = done.any() if isinstance(done, np.ndarray) else done + if done_any: + for i in np.nonzero(done)[0]: + print('episode_rew={}'.format(episode_rew[i])) + episode_rew[i] = 0 env.close()