Merge pull request #265 from 20chase/patch-1
fix logger error for trpo_mpi
This commit is contained in:
@@ -12,7 +12,10 @@ def train(env_id, num_timesteps, seed):
|
|||||||
sess.__enter__()
|
sess.__enter__()
|
||||||
|
|
||||||
rank = MPI.COMM_WORLD.Get_rank()
|
rank = MPI.COMM_WORLD.Get_rank()
|
||||||
if rank != 0:
|
if rank == 0:
|
||||||
|
logger.configure()
|
||||||
|
else:
|
||||||
|
logger.configure(format_strs=[])
|
||||||
logger.set_level(logger.DISABLED)
|
logger.set_level(logger.DISABLED)
|
||||||
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
|
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
|
||||||
def policy_fn(name, ob_space, ac_space):
|
def policy_fn(name, ob_space, ac_space):
|
||||||
@@ -25,9 +28,9 @@ def train(env_id, num_timesteps, seed):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = mujoco_arg_parser().parse_args()
|
args = mujoco_arg_parser().parse_args()
|
||||||
logger.configure()
|
|
||||||
train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
|
train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user