baselines/baselines/her/her.py
pzhokhov b875fb7b5e release Internal changes (#800)
* joshim5 changes (width and height to WarpFrame wrapper)

* match network output with action distribution via a linear layer only if necessary (#167)

* support color vs. grayscale option in WarpFrame wrapper (#166)

* support color vs. grayscale option in WarpFrame wrapper

* Support color in other wrappers

* Updated per Peters suggestions

* fixing test failures

* ppo2 with microbatches (#168)

* pass microbatch_size to the model during construction

* microbatch fixes and test (#169)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* Peterz joshim5 subclass ppo2 model (#170)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix

* subclassing the model to make microbatched version of model WIP

* made microbatched model a subclass of ppo2 Model

* flake8 complaint

* mpi-less ppo2 (resolving merge conflict)

* flake8 and mpi4py imports in ppo2/model.py

* more un-mpying

* merge master

* updates to the benchmark viewer code + autopep8 (#184)

* viz docs and syntactic sugar wip

* update viewer yaml to use persistent volume claims

* move plot_util to baselines.common, update links

* use 1Tb hard drive for results viewer

* small updates to benchmark vizualizer code

* autopep8

* autopep8

* any folder can be a benchmark

* massage games image a little bit

* fixed --preload option in app.py

* remove preload from run_viewer.sh

* remove pdb breakpoints

* update bench-viewer.yaml

* fixed bug (#185)

* fixed bug 

it's wrong to do the else statement, because no other nodes would start.

* changed the fix slightly

* Refactor her phase 1 (#194)

* add monitor to the rollout envs in her RUN BENCHMARKS her

* Slice -> Slide in her benchmarks RUN BENCHMARKS her

* run her benchmark for 200 epochs

* dummy commit to RUN BENCHMARKS her

* her benchmark for 500 epochs RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her

* disable saving of policies in her benchmark RUN BENCHMARKS her

* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch

* run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch

* launcher refactor wip

* wip

* her works on FetchReach

* her runner refactor RUN BENCHMARKS Fetch1M

* unit test for her

* fixing warnings in mpi_average in her, skip test_fetchreach if mujoco is not present

* pickle-based serialization in her

* remove extra import from subproc_vec_env.py

* investigating differences in rollout.py

* try with old rollout code RUN BENCHMARKS her

* temporarily use DummyVecEnv in cmd_util.py RUN BENCHMARKS her

* dummy commit to RUN BENCHMARKS her

* set info_values in rollout worker in her RUN BENCHMARKS her

* bug in rollout_new.py RUN BENCHMARKS her

* fixed bug in rollout_new.py RUN BENCHMARKS her

* do not use last step because vecenv calls reset and returns obs after reset RUN BENCHMARKS her

* updated buffer sizes RUN BENCHMARKS her

* fixed loading/saving via joblib

* dust off learning from demonstrations in HER, docs, refactor

* add deprecation notice on her play and plot files

* address comments by Matthias

* 1.5 months of codegen changes (#196)

* play with resnet

* feed_dict version

* coinrun prob and more stats

* fixes to get_choices_specs & hp search

* minor prob fixes

* minor fixes

* minor

* alternative version of rl_algo stuff

* pylint fixes

* fix bugs, move node_filters to soup

* changed how get_algo works

* change how get_algo works, probably broke all tests

* continue previous refactor

* get eval_agent running again

* fixing tests

* fix tests

* fix more tests

* clean up cma stuff

* fix experiment

* minor changes to eval_agent to make ppo_metal use gpu

* make dict space work

* modify mac makefile to use conda

* recurrent layers

* play with bn and resnets

* minor hp changes

* minor

* got rid of use_fb argument and jtft (joint-train-fine-tune) functionality
built test phase directly into AlgoProb

* make new rl algos generateable

* pylint; start fixing tests

* fixing tests

* more test fixes

* pylint

* fix search

* work on search

* hack around infinite loop caused by scan

* algo search fixes

* misc changes for search expt

* enable annealing, overriding options of Op

* pylint fixes

* identity op

* achieve use_last_output through masking so it automatically works in other distributions

* fix tests

* minor

* discrete

* use_last_output to be just a preference, not a hard constraint

* pred delay, pruning

* require nontrivial inputs

* aliases for get_sm

* add probname to probs

* fixes

* small fixes

* fix tests

* fix tests

* fix tests

* minor

* test scripts

* dualgru network improvements

* minor

* work on mysterious bugs

* rcall gpu-usage command for kube

* use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync

* add power mode to gpu usage

* make sure train/test actually different

* remove VR for now

* minor fixes

* simplify soln_db

* minor

* big refactor of mpi eda

* improve mpieda for multitask

* - get rid of timelimit hack
- add __del__ to cleanup SubprocVecEnv

* get multitask working better

* fixes

* working on atari, various

* annotate ops with whether they’re parametrized

* minor

* gym version

* rand atari prob

* minor

* SolnDb bugfix and name change

* pyspy script

* switch conv layers

* fix roboschool/bullet3

* nenvs assertion

* fix rand atari

* get rid of blanket exception catching
fix soln_db bug

* fix rand_atari

* dynamic routing as cmdline arg

* slight modifications to test_mpi_map and pyspy-all

* max_tries argument for run_until_successs

* dedup option in train_mle

* simplify soln_db

* increase atari horizon for 1 experiment

* start implementing reward increment

* ent multiplier

* create cc dsl
other misc fixes

* cc ops

* q_func -> qs in rl_algos_cc.py

* fix PredictDistr

* rl_ops_cc fixes, MakeAction op

* augment algo agent to support cc stuff

* work on ddpg experiments

* fix blocking
temporarily change logger

* allow layer scaling

* pylint fixes

* spawn_method

* isolate ddpg hacks

* improve pruning

* use spawn for subproc

* remove use of python -c in rcall

* fix pylint warning

* fix static

* maybe fix local backend

* switch to DummyVecEnv

* making some fixes via pylint

* pylint fixes

* fixing tests

* fix tests

* fix tests

* write scaffolding for SSL in Codegen

* logger fix

* fix error

* add EMA op to sl_ops

* save many changes

* save

* add upsampler

* add sl ops, enhance state machine

* get ssl search working — some gross hacking

* fix session/graph issue

* fix importing

* work on mle

* - scale embeddings in gru model
- better exception handling in sl_prob
- use emas for test/val
- use non-contrib batch_norm layer

* improve logging

* option to average before dumping in logger

* default arguments, etc

* new ddpg and identity test

* concat fix

* minor

* move realistic ssl stuff to third-party (underscore to dash)

* fixes

* remove realistic_ssl_evaluation

* pylint fixes

* use gym master

* try again

* pass around args without gin

* fix tests

* separate line to install gym

* rename failing tests that should be ignored

* add data aug

* ssl improvements

* use fixed time limit

* try to fix baselines tests

* add score_floor, max_walltime, fiddle with lr decay

* realistic_ssl

* autopep8

* various ssl
- enable blocking grad for simplification
- kl
- multiple final prediction

* fix pruning

* misc ssl stuff

* bring back linear schedule, don’t use allgather for collecting stats
(i’ve been getting nondeterministic errors from the old code)

* save/load weights in SSL, big stepsize

* cleanup SslProb

* fix

* get rid of kl coef

* fix simplification, lower lr

* search over hps

* minor fixes

* minor

* static analysis

* move files and rename things for improved consistency.
still broken, and just saving before making nontrivial changes

* various

* make tests pass

* move coinrun_train to codegen since it depends on codegen

* fixes

* pylint fixes

* improve tests
fix some things

* improve tests

* lint

* fix up db_info.py, tests

* mostly restore master version of envs directory, except for makefile changes

* fix tests

* improve printing

* minor fixes

* fix fixmes

* pruning test

* fixes

* lint

* write new test that makes tf graphs of random algos; fix some bugs it caught

* add --delete flag to rcall upload-code command

* lint

* get cifar10 lazily for testing purposes

* disable codegen ci tests for now

* clean up rl_ops

* rename spec classes

* td3 with identity test

* identity tests without gin files

* remove gin.configurable from AlgoAgent

* comments about reduction in rl_ops_cc

* address @pzhokhov comments

* fix tests

* more linting

* better tests

* clean up filtering a bit

* fix concat

* delayed logger configuration (#208)

* delayed logger configuration

* fix typo

* setters and getters for Logger.DEFAULT as well

* do away with fancy property stuff - unable to get it to work with class level methods

* grammar and spaces

* spaces

* use get_current function instead of reading Logger.CURRENT

* autopep8

* disable mpi in subprocesses (#213)

* lazy_mpi load

* cleanups

* more lazy mpi

* don't pretend that class is a module, just use it as a class

* mass-replace mpi4py imports

* flake8

* fix previous lazy_mpi imports

* silly recursion

* try os.environ hack

* better prefix test, work with mpich

* restored MPI imports

* removed commented import in test_with_mpi

* restored codegen from master

* remove lazy mpi

* restored changes from rl-algs

* remove extra files

* address Chris' comments

* use spawn for shmem vec env as well (#2) (#219)

* lazy_mpi load

* cleanups

* more lazy mpi

* don't pretend that class is a module, just use it as a class

* mass-replace mpi4py imports

* flake8

* fix previous lazy_mpi imports

* silly recursion

* try os.environ hack

* better prefix test, work with mpich

* restored MPI imports

* removed commented import in test_with_mpi

* restored codegen from master

* remove lazy mpi

* restored changes from rl-algs

* remove extra files

* port mpi fix to shmem vec env

* increase the mpi test default timeout

* change humanoid hyperparameters, get rid of clip_Frac annealing, as it's apparently dangerous

* remove clip_frac schedule from ppo2

* more timesteps in humanoid run

* whitespace + RUN BENCHMARKS

* baselines: export vecenvs from folder (#221)

* baselines: export vecenvs from folder

* put missing function back in

* add missing imports

* more imports

* longer mpi timeout?

* make default logger configuration the same as call to logger.configure() (#222)

* Vecenv refactor (#223)

* update karl util

* restore pvi flag

* change rcall auto cpu behavior, move gin.configurable, add os.makedirs

* vecenv refactor

* aux buf index fix

* add num aux obs

* reset level with enter

* restore high difficulty flag

* bugfix

* restore train_coinrun.py

* tweaks

* renaming

* renaming

* better arguments handling

* more options

* options cleanup

* game data refactor

* more options

* args for train_procgen

* add close handler to interactive base class

* use debug build if debug=True, fix range on aux_obs

* add ProcGenEnv to __init__.py, add missing imports to procgen.py

* export RemoveDictWrapper and build, update train_procgen.py, move assets download into env creation and replace init_assets_and_build with just build

* fix formatting issues

* only call global init once

* fix path in setup.py

* revert part of makefile

* ignore IDE files and folders

* vec remove dict

* export VecRemoveDictObs

* remove RemoveDictWrapper

* remove IDE files

* move shared .h and .cpp files to common folder, update build to use those, dedupe env.cpp

* fix missing header

* try unified build function

* remove old scripts dir

* add comment on build

* upload libenv with render fixes

* tell qthreads to die when we unload the library

* pyglet.app.run is garbage

* static fixes

* whoops

* actually vsync is on

* cleanup

* cleanup

* extern C for libenv interface

* parse util rcall arg

* high difficulty fix

* game type enums

* ProcGenEnv subclasses

* game type cleanup

* unrecognized key

* unrecognized game type

* parse util reorg

* args management

* typo fix

* GinParser

* arg tweaks

* tweak

* restore start_level/num_levels setting

* fix create_procgen_env interface

* build fix

* procgen args in init signature

* fix

* build fix

* fix logger usage in ppo_metal/run_retro

* removed unnecessary OrderedDict requirement in subproc_vec_env

* flake8 fix

* allow for non-mpi tests

* mpi test fixes

* flake8; removed special logic for discrete spaces in dummy_vec_env

* remove forked argument in front of tests - does not play nicely with subprocvecenv in spawned processes; analog of forked in ddpg/test_smoke

* Everyrl initial commit & a few minor baselines changes (#226)

* everyrl initial commit

* add keep_buf argument to VecMonitor

* logger changes: set_comm and fix to mpi_mean functionality

* if filename not provided, don't create ResultsWriter

* change variable syncing function to simplify its usage. now you should initialize from all mpi processes

* everyrl coinrun changes

* tf_distr changes, bugfix

* get_one

* bring back get_next to temporarily restore code

* lint fixes

* fix test

* rename profile function

* rename gaussian

* fix coinrun training script

* change random seeding to work with new gym version (#231)

* change random seeding to work with new gym version

* move seeding to seed() method

* fix mnistenv

* actually try some of the tests before pushing

* more deterministic fixed seq

* misc changes to vecenvs and run.py for benchmarks (#236)

* misc changes to vecenvs and run.py for benchmarks

* dont seed global gen

* update more references to assert_venvs_equal

* Rl19 (#232)

* everyrl initial commit

* add keep_buf argument to VecMonitor

* logger changes: set_comm and fix to mpi_mean functionality

* if filename not provided, don't create ResultsWriter

* change variable syncing function to simplify its usage. now you should initialize from all mpi processes

* everyrl coinrun changes

* tf_distr changes, bugfix

* get_one

* bring back get_next to temporarily restore code

* lint fixes

* fix test

* rename profile function

* rename gaussian

* fix coinrun training script

* rl19

* remove everyrl dir which appeared in the merge for some reason

* readme

* fiddle with ddpg

* make ddpg work

* steps_total argument

* gpu count

* clean up hyperparams and shape math

* logging + saving

* configuration stuff

* fixes, smoke tests

* fix stats

* make load_results return dicts -- easier to create the same kind of objects with some other mechanism for passing to downstream functions

* benchmarks

* fix tests

* add dqn to tests, fix it

* minor

* turned annotated transformer (pytorch) into a script

* more refactoring

* jax stuff

* cluster

* minor

* copy & paste alec code

* sign error

* add huber, rename some parameters, snapshotting off by default

* remove jax stuff

* minor

* move maze env

* minor

* remove trailing spaces

* remove trailing space

* lint

* fix test breakage due to gym update

* rename function

* move maze back to codegen

* get recurrent ppo working

* enable both lstm and gru

* script to print table of benchmark results

* various

* fix dqn

* add fixup initializer, remove lastrew

* organize logging stats

* fix silly bug

* refactor models

* fix mpi usage

* check sync

* minor

* change vf coef, hps

* clean up slicing in ppo

* minor fixes

* caching transformer

* docstrings

* xf fixes

* get rid of 'B' and 'BT' arguments

* minor

* transformer example

* remove output_kind from base class until we have a better idea how to use it

* add comments, revert maze stuff

* flake8

* codegen lint

* fix codegen tests

* responded to peter's comments

* lint fixes

* minor changes to baselines (#243)

* minor changes to baselines

* fix spaces reference

* remove flake8 disable comments and fix import

* okay maybe don't add spec to vec_env

* Merge branch 'master' of github.com:openai/games

* flake8 complaints in baselines/her
2019-02-27 15:35:31 -08:00

194 lines
7.3 KiB
Python

import os
import click
import numpy as np
import json
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds, tf_util
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
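

def mpi_average(value):
    # Average a scalar or a list of logged values over all MPI workers;
    # an empty (or all-zero) list is treated as a single 0.
    if not isinstance(value, list):
        value = [value]
    if not any(value):
        value = [0.]
    return mpi_moments(np.array(value))[0]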


def train(*, policy, rollout_worker, evaluator,
          n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
          save_path, demo_file, **kwargs):
    rank = MPI.COMM_WORLD.Get_rank()

    if save_path:
        latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
        best_policy_path = os.path.join(save_path, 'policy_best.pkl')
        periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')

    logger.info("Training...")
    best_success_rate = -1

    if policy.bc_loss == 1:
        policy.init_demo_buffer(demo_file)  # initialize demo buffer if training with demonstrations

    # num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
    for epoch in range(n_epochs):
        # train
        rollout_worker.clear_history()
        for _ in range(n_cycles):
            episode = rollout_worker.generate_rollouts()
            policy.store_episode(episode)
            for _ in range(n_batches):
                policy.train()
            policy.update_target_net()

        # test
        evaluator.clear_history()
        for _ in range(n_test_rollouts):
            evaluator.generate_rollouts()

        # record logs
        logger.record_tabular('epoch', epoch)
        for key, val in evaluator.logs('test'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in rollout_worker.logs('train'):
            logger.record_tabular(key, mpi_average(val))
        for key, val in policy.logs():
            logger.record_tabular(key, mpi_average(val))

        if rank == 0:
            logger.dump_tabular()

        # save the policy if it's better than the previous ones
        # (mpi_average is collective, so every rank must call it)
        success_rate = mpi_average(evaluator.current_success_rate())
        if rank == 0 and success_rate >= best_success_rate and save_path:
            best_success_rate = success_rate
            logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
            evaluator.save_policy(best_policy_path)
            evaluator.save_policy(latest_policy_path)
        if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
            policy_path = periodic_policy_path.format(epoch)
            logger.info('Saving periodic policy to {} ...'.format(policy_path))
            evaluator.save_policy(policy_path)

        # make sure that different MPI workers have different seeds
        local_uniform = np.random.uniform(size=(1,))
        root_uniform = local_uniform.copy()
        MPI.COMM_WORLD.Bcast(root_uniform, root=0)
        if rank != 0:
            assert local_uniform[0] != root_uniform[0]

    return policy


def learn(*, network, env, total_timesteps,
    seed=None,
    eval_env=None,
    replay_strategy='future',
    policy_save_interval=5,
    clip_return=True,
    demo_file=None,
    override_params=None,
    load_path=None,
    save_path=None,
    **kwargs
):

    override_params = override_params or {}
    # mpi4py is imported unconditionally at the top of this file, as in train().
    rank = MPI.COMM_WORLD.Get_rank()
    num_cpu = MPI.COMM_WORLD.Get_size()

    # Seed everything.
    rank_seed = seed + 1000000 * rank if seed is not None else None
    set_global_seeds(rank_seed)

    # Prepare params.
    params = config.DEFAULT_PARAMS.copy()  # copy, so the mutations below leave the module-level defaults intact
    env_name = env.spec.id
    params['env_name'] = env_name
    params['replay_strategy'] = replay_strategy
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
    params.update(**override_params)  # makes it possible to override any parameter
    with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)
    params = config.prepare_params(params)
    params['rollout_batch_size'] = env.num_envs

    if demo_file is not None:
        params['bc_loss'] = 1
    params.update(kwargs)

    config.log_params(params, logger=logger)

    if num_cpu == 1:
        logger.warn()
        logger.warn('*** Warning ***')
        logger.warn(
            'You are running HER with just a single MPI worker. This will work, but the ' +
            'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
            'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
            'are looking to reproduce those results, be aware of this. Please also refer to ' +
            'https://github.com/openai/baselines/issues/314 for further details.')
        logger.warn('****************')
        logger.warn()

    dims = config.configure_dims(params)
    policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
    if load_path is not None:
        tf_util.load_variables(load_path)

    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': False,
        'T': params['T'],
    }

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'use_demo_states': False,
        'compute_Q': True,
        'T': params['T'],
    }

    for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
        rollout_params[name] = params[name]
        eval_params[name] = params[name]

    eval_env = eval_env or env

    rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
    evaluator = RolloutWorker(eval_env, policy, dims, logger, **eval_params)

    n_cycles = params['n_cycles']
    n_epochs = total_timesteps // n_cycles // rollout_worker.T // rollout_worker.rollout_batch_size

    return train(
        save_path=save_path, policy=policy, rollout_worker=rollout_worker,
        evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
        n_cycles=params['n_cycles'], n_batches=params['n_batches'],
        policy_save_interval=policy_save_interval, demo_file=demo_file)


@click.command()
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
@click.option('--total_timesteps', type=int, default=int(5e5), help='the number of timesteps to run')
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
@click.option('--demo_file', type=str, default=None, help='optional demo data file path (.npz); if given, bc_loss is set to 1')
def main(**kwargs):
    # NOTE: learn() requires a `network` argument and a VecEnv (it reads
    # env.spec.id and env.num_envs), while this CLI passes only an env id
    # string, so this entry point fails as written. HER is normally launched
    # through baselines.run, e.g. `python -m baselines.run --alg=her --env=FetchReach-v1`.
    learn(**kwargs)


if __name__ == '__main__':
    main()
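`mpi_average` forwards to `mpi_moments`, which, as far as I can tell from `baselines.common.mpi_moments`, all-reduces per-worker element sums and element counts. The logged number is therefore an element-weighted mean over all workers, not a mean of per-worker means. The sketch below models that reduction in a single process; `simulated_mpi_average` is a hypothetical helper, and the list of per-worker values stands in for real MPI ranks. It also exposes a side effect of the `not any(value)` guard: an all-zero list is collapsed to a single `0.`, so that worker contributes one element to the global average instead of several.

```python
import numpy as np


def simulated_mpi_average(per_worker_values):
    """Single-process model of mpi_average() evaluated on several workers.

    per_worker_values holds, for each hypothetical worker, the value that
    worker would pass to mpi_average (a scalar or a list).
    """
    normalized = []
    for value in per_worker_values:
        # Local normalisation, copied from mpi_average().
        if not isinstance(value, list):
            value = [value]
        if not any(value):
            value = [0.]
        normalized.append(np.asarray(value, dtype=np.float64))
    # Model of the MPI reduction: global sum of elements / global element count.
    total = sum(arr.sum() for arr in normalized)
    count = sum(arr.size for arr in normalized)
    return float(total / count)


# Worker 0 logged two values, worker 1 a scalar, worker 2 nothing at all.
print(simulated_mpi_average([[1.0, 0.0], 3.0, []]))  # prints 1.0
```

A mean of per-worker means would give (0.5 + 3.0 + 0.0) / 3 here; the element-weighted version gives 4 / 4.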
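`learn()` derives the epoch count from `total_timesteps` by chained floor division, the inverse of the comment in `train()` (timesteps = n_epochs * n_cycles * rollout length * rollout workers, where `rollout_batch_size` is set from `env.num_envs`). A small sketch, with `her_epoch_count` and `timesteps_covered` as hypothetical helpers and n_cycles=50, T=50, rollout_batch_size=2 as assumed illustrative values rather than values read from the config:

```python
def her_epoch_count(total_timesteps, n_cycles, T, rollout_batch_size):
    # Same chained floor division as learn().
    return total_timesteps // n_cycles // T // rollout_batch_size


def timesteps_covered(n_epochs, n_cycles, T, rollout_batch_size):
    # The formula from the comment in train():
    # n_epochs * n_cycles * rollout_length * number of rollout workers.
    return n_epochs * n_cycles * T * rollout_batch_size


# Illustrative values only (assumed, not read from baselines.her.experiment.config).
N_CYCLES, T, BATCH = 50, 50, 2

for requested in (500000, 12345, 4000):
    epochs = her_epoch_count(requested, N_CYCLES, T, BATCH)
    print(requested, epochs, timesteps_covered(epochs, N_CYCLES, T, BATCH))
# 500000 100 500000
# 12345 2 10000
# 4000 0 0
```

Because the division truncates, a request that is not a multiple of the per-epoch step count is rounded down, and a small request yields zero epochs, in which case the epoch loop in `train()` never runs. The division does not involve `num_cpu`, so in this file each MPI worker computes the same epoch count on its own.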
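The save logic at the end of each epoch in `train()` can be replayed on a list of per-epoch success rates. This sketch assumes rank 0 and a non-empty `save_path` (the only case in which anything is written) and reads the `policy_latest.pkl` save as belonging to the new-best branch, the only reading under which `latest_policy_path` is always defined when used. `checkpoint_events` is a hypothetical helper.

```python
def checkpoint_events(success_rates, policy_save_interval):
    """Replay train()'s save conditions on per-epoch success rates.

    Assumes rank 0 and a non-empty save_path. Returns (epoch, kind) tuples.
    """
    events = []
    best_success_rate = -1
    for epoch, success_rate in enumerate(success_rates):
        # '>=' means matching the previous best also counts as a new best;
        # this branch writes both policy_best.pkl and policy_latest.pkl.
        if success_rate >= best_success_rate:
            best_success_rate = success_rate
            events.append((epoch, 'best+latest'))
        # Periodic saves include epoch 0, because 0 % interval == 0.
        if policy_save_interval > 0 and epoch % policy_save_interval == 0:
            events.append((epoch, 'periodic'))
    return events


print(checkpoint_events([0.1, 0.1, 0.05, 0.3], policy_save_interval=2))
```

Two consequences follow from the conditions: ties with the best rate overwrite the best pickle, and the "latest" pickle is refreshed only when the best rate is matched or beaten, not every epoch.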
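`learn()` builds its parameters in layers: global defaults, then `env_name` and `replay_strategy`, then env-specific defaults, then `override_params`. The sketch models only that part of the precedence; `merge_her_params` is a hypothetical helper and the dictionaries are made-up stand-ins for `config.DEFAULT_PARAMS` and `config.DEFAULT_ENV_PARAMS`. After `config.prepare_params`, `learn()` still sets `rollout_batch_size` from `env.num_envs`, sets `bc_loss` when `demo_file` is given, and finally applies `kwargs`, which therefore take precedence over everything modeled here.

```python
def merge_her_params(default_params, default_env_params, env_name,
                     replay_strategy, override_params):
    """Model of the precedence learn() applies before config.prepare_params."""
    params = dict(default_params)  # work on a copy of the defaults
    params['env_name'] = env_name
    params['replay_strategy'] = replay_strategy
    if env_name in default_env_params:
        params.update(default_env_params[env_name])  # env-specific values win over defaults
    params.update(override_params)  # explicit overrides win over both
    return params


# Hypothetical stand-ins for config.DEFAULT_PARAMS / config.DEFAULT_ENV_PARAMS.
DEFAULTS = {'n_cycles': 50, 'n_batches': 40, 'replay_k': 4}
ENV_DEFAULTS = {'FetchReach-v1': {'n_cycles': 10}}

merged = merge_her_params(DEFAULTS, ENV_DEFAULTS, 'FetchReach-v1',
                          'future', {'n_batches': 5})
print(merged)
```

Copying first means a second call in the same process starts from the same base values.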
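Per-rank seeding: `learn()` offsets the seed by `1000000 * rank` before calling `set_global_seeds`, and `train()` later broadcasts rank 0's `np.random.uniform` draw and asserts that every other rank drew a different value. The sketch models only the NumPy side with a private `np.random.RandomState` per rank. That is an assumption: as I understand it, `set_global_seeds` also seeds other generators, and in `train()` the check runs after rollouts have already consumed random numbers, so this illustrates only the starting point. `rank_seed` and `first_uniform_draw` are hypothetical helpers.

```python
import numpy as np


def rank_seed(seed, rank):
    # Same formula as learn(): offset each MPI rank by one million.
    return seed + 1000000 * rank if seed is not None else None


def first_uniform_draw(seed):
    # Stand-in for np.random.seed(seed) followed by np.random.uniform(size=(1,)),
    # using a private RandomState so the global generator is left alone.
    return np.random.RandomState(seed).uniform(size=(1,))[0]


seeds = [rank_seed(0, rank) for rank in range(4)]
draws = [first_uniform_draw(s) for s in seeds]
print(seeds)  # prints [0, 1000000, 2000000, 3000000]
```

Spacing the seeds by a large constant keeps ranks apart for any reasonable base seed, while `seed=None` leaves every rank unseeded.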