from gym import RewardWrapper class TransformReward(RewardWrapper): r"""Transform the reward via an arbitrary function. Example:: >>> import gym >>> env = gym.make('CartPole-v1') >>> env = TransformReward(env, lambda r: 0.01*r) >>> env.reset() >>> observation, reward, done, info = env.step(env.action_space.sample()) >>> reward 0.01 Args: env (Env): environment f (callable): a function that transforms the reward """ def __init__(self, env, f): super().__init__(env) assert callable(f) self.f = f def reward(self, reward): return self.f(reward)