2022-11-18 22:25:33 +01:00
""" Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments. """
2022-11-29 23:37:53 +00:00
from __future__ import annotations
2022-11-18 22:25:33 +01:00
2022-11-29 23:37:53 +00:00
from typing import Any , Callable , Generic , TypeVar
2022-11-18 22:25:33 +01:00
import numpy as np
2023-02-12 07:49:37 -05:00
from gymnasium import Space
2022-12-04 22:24:02 +08:00
2022-11-18 22:25:33 +01:00
# Type variables parameterizing ``FuncEnv``. Nothing in this module constrains
# them; each concrete environment chooses its own types.
StateType = TypeVar("StateType")  # environment state threaded explicitly between calls
ActType = TypeVar("ActType")  # action accepted by ``transition`` / ``reward``
ObsType = TypeVar("ObsType")  # value returned by ``observation``
RewardType = TypeVar("RewardType")  # value returned by ``reward``
TerminalType = TypeVar("TerminalType")  # value returned by ``terminal``
RenderStateType = TypeVar("RenderStateType")  # state threaded through the ``render_*`` methods


class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
):
    """Base class (template) for functional envs.

    This API is meant to be used in a stateless manner, with the environment state being passed around explicitly.
    That being said, nothing here prevents users from using the environment statefully; it's just not recommended.
    A functional env consists of the following functions (in this case, instance methods):

    * initial: returns the initial state of the POMDP
    * observation: returns the observation in a given state
    * transition: returns the next state after taking an action in a given state
    * reward: returns the reward for a given (state, action, next_state) tuple
    * terminal: returns whether a given state is terminal
    * state_info: optional, returns a dict of info about a given state
    * transition_info: optional, returns a dict of info about a given (state, action, next_state) tuple

    Rendering is handled by ``render_init``, ``render_image`` and ``render_close``, which pass an explicit
    render state between calls. ``transform`` wraps the core functions listed above, but not the render functions.

    The class-based structure serves the purpose of allowing environment constants to be defined in the class,
    and then using them by name in the code itself.

    For the moment, this is predominantly for internal use. This API is likely to change, but in the future
    we intend to flesh it out and officially expose it to end users.
    """

    observation_space: Space
    action_space: Space

    def __init__(self, options: dict[str, Any] | None = None) -> None:
        """Initialize the environment constants.

        Args:
            options: Optional mapping whose entries are set directly as instance attributes.
                Because they land in the instance ``__dict__``, they shadow class-level constants
                (and methods) of the same name. ``None`` means no overrides.
        """
        self.__dict__.update(options or {})

    def initial(self, rng: Any) -> StateType:
        """Generate the initial state of the environment.

        Args:
            rng: Random number generator; its concrete type is left to the implementation.

        Returns:
            The initial environment state.
        """
        raise NotImplementedError

    def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
        """Compute the next state after taking ``action`` in ``state``.

        Args:
            state: Current environment state.
            action: Action to apply.
            rng: Random number generator; its concrete type is left to the implementation.

        Returns:
            The next environment state.
        """
        raise NotImplementedError

    def observation(self, state: StateType) -> ObsType:
        """Generate the observation for a given environment state."""
        raise NotImplementedError

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> RewardType:
        """Compute the reward for the transition ``state`` --``action``--> ``next_state``."""
        raise NotImplementedError

    def terminal(self, state: StateType) -> TerminalType:
        """Return whether ``state`` is a final (terminal) state."""
        raise NotImplementedError

    def state_info(self, state: StateType) -> dict:
        """Return an info dict about a single state (empty by default)."""
        return {}

    def transition_info(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> dict:
        """Return an info dict about a full transition (empty by default)."""
        return {}

    def transform(self, func: Callable[[Callable], Callable]) -> None:
        """Apply a functional transformation to the core environment functions.

        Each of ``initial``, ``transition``, ``observation``, ``reward``, ``terminal``,
        ``state_info`` and ``transition_info`` is replaced on this instance by
        ``func(method)``. The render functions are intentionally left untouched.

        Args:
            func: Transformation taking a callable and returning its replacement.
        """
        self.initial = func(self.initial)
        self.transition = func(self.transition)

        self.observation = func(self.observation)
        self.reward = func(self.reward)
        self.terminal = func(self.terminal)

        self.state_info = func(self.state_info)
        self.transition_info = func(self.transition_info)

    def render_image(
        self, state: StateType, render_state: RenderStateType
    ) -> tuple[RenderStateType, np.ndarray]:
        """Render ``state`` to an image.

        Args:
            state: Environment state to render.
            render_state: Render state, as produced by ``render_init``.

        Returns:
            A ``(render_state, image)`` tuple, where ``image`` is a NumPy array.
        """
        raise NotImplementedError

    def render_init(self, **kwargs) -> RenderStateType:
        """Initialize and return the render state."""
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType) -> None:
        """Close (release) the given render state."""
        raise NotImplementedError