diff --git a/baselines/bench/benchmarks.py b/baselines/bench/benchmarks.py index e9328b2..c9fdd14 100644 --- a/baselines/bench/benchmarks.py +++ b/baselines/bench/benchmarks.py @@ -97,6 +97,19 @@ register_benchmark({ ] }) +# Bullet +_bulletsmall = [ + 'InvertedDoublePendulum', 'InvertedPendulum', 'HalfCheetah', 'Reacher', 'Walker2D', 'Hopper', 'Ant' +] +_bulletsmall = [e + 'BulletEnv-v0' for e in _bulletsmall] + +register_benchmark({ + 'name': 'Bullet1M', + 'description': '7 mujoco-like tasks from bullet, 1M steps', + 'tasks': [{'env_id': e, 'trials': 6, 'num_timesteps': int(1e6)} for e in _bulletsmall] +}) + + # Roboschool register_benchmark({ diff --git a/baselines/bench/monitor.py b/baselines/bench/monitor.py index 8024ea0..0db473a 100644 --- a/baselines/bench/monitor.py +++ b/baselines/bench/monitor.py @@ -16,21 +16,11 @@ class Monitor(Wrapper): def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()): Wrapper.__init__(self, env=env) self.tstart = time.time() - if filename is None: - self.f = None - self.logger = None - else: - if not filename.endswith(Monitor.EXT): - if osp.isdir(filename): - filename = osp.join(filename, Monitor.EXT) - else: - filename = filename + "." 
+ Monitor.EXT - self.f = open(filename, "wt") - self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id})) - self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords) - self.logger.writeheader() - self.f.flush() - + self.results_writer = ResultsWriter( + filename, + header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id}, + extra_keys=reset_keywords + info_keywords + ) self.reset_keywords = reset_keywords self.info_keywords = info_keywords self.allow_early_resets = allow_early_resets @@ -43,10 +33,7 @@ class Monitor(Wrapper): self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() def reset(self, **kwargs): - if not self.allow_early_resets and not self.needs_reset: - raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)") - self.rewards = [] - self.needs_reset = False + self.reset_state() for k in self.reset_keywords: v = kwargs.get(k) if v is None: @@ -54,10 +41,21 @@ class Monitor(Wrapper): self.current_reset_info[k] = v return self.env.reset(**kwargs) + def reset_state(self): + if not self.allow_early_resets and not self.needs_reset: + raise RuntimeError("Tried to reset an environment before done. 
If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)") + self.rewards = [] + self.needs_reset = False + + def step(self, action): if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") ob, rew, done, info = self.env.step(action) + self.update(ob, rew, done, info) + return (ob, rew, done, info) + + def update(self, ob, rew, done, info): self.rewards.append(rew) if done: self.needs_reset = True @@ -70,12 +68,12 @@ class Monitor(Wrapper): self.episode_lengths.append(eplen) self.episode_times.append(time.time() - self.tstart) epinfo.update(self.current_reset_info) - if self.logger: - self.logger.writerow(epinfo) - self.f.flush() - info['episode'] = epinfo + self.results_writer.write_row(epinfo) + + if isinstance(info, dict): + info['episode'] = epinfo + self.total_steps += 1 - return (ob, rew, done, info) def close(self): if self.f is not None: @@ -96,6 +94,34 @@ class Monitor(Wrapper): class LoadMonitorResultsError(Exception): pass + +class ResultsWriter(object): + def __init__(self, filename=None, header='', extra_keys=()): + self.extra_keys = extra_keys + if filename is None: + self.f = None + self.logger = None + else: + if not filename.endswith(Monitor.EXT): + if osp.isdir(filename): + filename = osp.join(filename, Monitor.EXT) + else: + filename = filename + "." 
+ Monitor.EXT + self.f = open(filename, "wt") + if isinstance(header, dict): + header = '# {} \n'.format(json.dumps(header)) + self.f.write(header) + self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys)) + self.logger.writeheader() + self.f.flush() + + def write_row(self, epinfo): + if self.logger: + self.logger.writerow(epinfo) + self.f.flush() + + + def get_monitor_files(dir): return glob(osp.join(dir, "*" + Monitor.EXT)) diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index cb4f054..d69589c 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -121,11 +121,18 @@ def parse_unknown_args(args): Parse arguments not consumed by arg parser into a dicitonary """ retval = {} + preceded_by_key = False for arg in args: - assert arg.startswith('--') - assert '=' in arg, 'cannot parse arg {}'.format(arg) - key = arg.split('=')[0][2:] - value = arg.split('=')[1] - retval[key] = value + if arg.startswith('--'): + if '=' in arg: + key = arg.split('=')[0][2:] + value = arg.split('=')[1] + retval[key] = value + else: + key = arg[2:] + preceded_by_key = True + elif preceded_by_key: + retval[key] = arg + preceded_by_key = False return retval diff --git a/baselines/common/console_util.py b/baselines/common/console_util.py index a7e94c0..3b011c5 100644 --- a/baselines/common/console_util.py +++ b/baselines/common/console_util.py @@ -58,6 +58,9 @@ def print_cmd(cmd, dry=False): def get_git_commit(cwd=None): return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=cwd).decode('utf8') +def get_git_commit_message(cwd=None): + return subprocess.check_output(['git', 'show', '-s', '--format=%B', 'HEAD'], cwd=cwd).decode('utf8') + def ccap(cmd, dry=False, env=None, **kwargs): print_cmd(cmd, dry) if not dry: diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 4a84035..eaddbdf 100644 --- a/baselines/common/distributions.py +++ 
b/baselines/common/distributions.py @@ -23,6 +23,11 @@ class Pd(object): raise NotImplementedError def logp(self, x): return - self.neglogp(x) + def get_shape(self): + return self.flatparam().shape + @property + def shape(self): + return self.get_shape() class PdType(object): """ @@ -145,10 +150,22 @@ class CategoricalPd(Pd): # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) # Note: we can't use sparse_softmax_cross_entropy_with_logits because # the implementation does not allow second-order derivatives... - one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) + if x.dtype in {tf.uint8, tf.int32, tf.int64}: + # one-hot encoding + x_shape_list = x.shape.as_list() + logits_shape_list = self.logits.get_shape().as_list()[:-1] + for xs, ls in zip(x_shape_list, logits_shape_list): + if xs is not None and ls is not None: + assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls) + + x = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) + else: + # already encoded + assert x.shape.as_list() == self.logits.shape.as_list() + return tf.nn.softmax_cross_entropy_with_logits_v2( logits=self.logits, - labels=one_hot_actions) + labels=x) def kl(self, other): a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True) a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True) @@ -216,13 +233,19 @@ class DiagGaussianPd(Pd): @classmethod def fromflat(cls, flat): return cls(flat) + def __getitem__(self, idx): + return DiagGaussianPd(self.flat[idx]) + class BernoulliPd(Pd): def __init__(self, logits): self.logits = logits self.ps = tf.sigmoid(logits) def flatparam(self): - return self.logits + return self.logits + @property + def mean(self): + return self.ps def mode(self): return tf.round(self.ps) def neglogp(self, x): diff --git a/baselines/common/vec_env/vec_monitor.py b/baselines/common/vec_env/vec_monitor.py index 0074aee..960f76c 100644 --- a/baselines/common/vec_env/vec_monitor.py +++ 
b/baselines/common/vec_env/vec_monitor.py @@ -1,12 +1,16 @@ from . import VecEnvWrapper +from baselines.bench.monitor import ResultsWriter import numpy as np +import time class VecMonitor(VecEnvWrapper): - def __init__(self, venv): + def __init__(self, venv, filename=None): VecEnvWrapper.__init__(self, venv) self.eprets = None self.eplens = None + self.tstart = time.time() + self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart}) def reset(self): obs = self.venv.reset() @@ -22,8 +26,12 @@ class VecMonitor(VecEnvWrapper): for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)): info = info.copy() if done: - info['episode'] = {'r': ret, 'l': eplen} + epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} + info['episode'] = epinfo self.eprets[i] = 0 self.eplens[i] = 0 + self.results_writer.write_row(epinfo) + newinfos.append(info) + return obs, rews, dones, newinfos diff --git a/baselines/deepq/experiments/train_pong.py b/baselines/deepq/experiments/train_pong.py index b031021..8739aed 100644 --- a/baselines/deepq/experiments/train_pong.py +++ b/baselines/deepq/experiments/train_pong.py @@ -30,6 +30,5 @@ def main(): model.save('pong_model.pkl') env.close() - if __name__ == '__main__': main() diff --git a/baselines/run.py b/baselines/run.py index 3ac3d81..cc1a512 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -154,9 +154,6 @@ def get_default_network(env_type): else: return 'mlp' - raise ValueError('Unknown env_type {}'.format(env_type)) - - def get_alg_module(alg, submodule=None): submodule = submodule or alg try: @@ -182,16 +179,21 @@ def get_learn_function_defaults(alg, env_type): return kwargs -def parse(v): - ''' - convert value of a command-line arg to a python object if possible, othewise, keep as string - ''' - assert isinstance(v, str) - try: - return eval(v) - except (NameError, SyntaxError): - return v +def parse_cmdline_kwargs(args): + ''' + convert a list of '='-spaced 
command-line arguments to a dictionary, evaluating python objects when possible + ''' + def parse(v): + + assert isinstance(v, str) + try: + return eval(v) + except (NameError, SyntaxError): + return v + + return {k: parse(v) for k,v in parse_unknown_args(args).items()} + def main(): @@ -199,7 +201,7 @@ def main(): arg_parser = common_arg_parser() args, unknown_args = arg_parser.parse_known_args() - extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()} + extra_args = parse_cmdline_kwargs(unknown_args) if MPI is None or MPI.COMM_WORLD.Get_rank() == 0: rank = 0 diff --git a/setup.py b/setup.py index a9648fa..b89f777 100644 --- a/setup.py +++ b/setup.py @@ -10,10 +10,12 @@ extras = { 'test': [ 'filelock', 'pytest' + ], + 'bullet': [ + 'pybullet', ] } - all_deps = [] for group_name in extras: all_deps += extras[group_name]