diff --git a/baselines/run.py b/baselines/run.py index 4aaf1a7..28cf620 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -131,7 +131,7 @@ def get_env_type(env_id): def get_default_network(env_type): - if env_type == 'atari': + if env_type in {'atari', 'retro'}: return 'cnn' else: return 'mlp'