2022-11-18 22:25:33 +01:00
""" Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments. """
2024-06-10 17:07:47 +01:00
2022-11-29 23:37:53 +00:00
from __future__ import annotations
2022-11-18 22:25:33 +01:00
2022-11-29 23:37:53 +00:00
from typing import Any , Callable , Generic , TypeVar
2022-11-18 22:25:33 +01:00
import numpy as np
2023-02-12 07:49:37 -05:00
from gymnasium import Space
2022-12-04 22:24:02 +08:00
2022-11-18 22:25:33 +01:00
# Type variables parameterising `FuncEnv`.
# Each TypeVar's name string must match the variable it is bound to; static
# type checkers reject a mismatch and the name is what shows up in reprs.
StateType = TypeVar("StateType")  # environment state passed between calls
ActType = TypeVar("ActType")  # action accepted by `transition` / `reward`
ObsType = TypeVar("ObsType")  # value returned by `observation`
RewardType = TypeVar("RewardType")  # value returned by `reward`
TerminalType = TypeVar("TerminalType")  # value returned by `terminal`
RenderStateType = TypeVar("RenderStateType")  # value returned by `render_init`
Params = TypeVar("Params")  # optional `params` argument of the env functions
2022-11-18 22:25:33 +01:00
class FuncEnv(
    Generic[
        StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType, Params
    ]
):
    """Base class (template) for functional envs.

    This API is meant to be used in a stateless manner, with the environment state being passed around explicitly.
    That being said, nothing here prevents users from using the environment statefully, it's just not recommended.

    A functional env consists of the following functions (in this case, instance methods):

    * initial: returns the initial state of the POMDP
    * observation: returns the observation in a given state
    * transition: returns the next state after taking an action in a given state
    * reward: returns the reward for a given (state, action, next_state) tuple
    * terminal: returns whether a given state is terminal
    * state_info: optional, returns a dict of info about a given state
    * transition_info: optional, returns a dict of info about a given (state, action, next_state) tuple

    The class-based structure serves the purpose of allowing environment constants to be defined in the class,
    and then using them by name in the code itself.

    For the moment, this is predominantly for internal use. This API is likely to change, but in the future
    we intend to flesh it out and officially expose it to end users.
    """

    observation_space: Space
    action_space: Space

    def __init__(self, options: dict[str, Any] | None = None):
        """Initialize the environment constants.

        Args:
            options: Optional mapping; every entry is copied onto the instance
                as an attribute. ``None`` (or an empty mapping) adds nothing.
        """
        # Options are applied first so that `get_default_params` can read them,
        # and so that `default_params` always comes from `get_default_params`.
        self.__dict__.update(options or {})
        self.default_params = self.get_default_params()

    def initial(self, rng: Any, params: Params | None = None) -> StateType:
        """Generates the initial state of the environment with a random number generator.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def transition(
        self, state: StateType, action: ActType, rng: Any, params: Params | None = None
    ) -> StateType:
        """Updates (transitions) the state with an action and random number generator.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def observation(
        self, state: StateType, rng: Any, params: Params | None = None
    ) -> ObsType:
        """Generates an observation for a given state of an environment.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def reward(
        self,
        state: StateType,
        action: ActType,
        next_state: StateType,
        rng: Any,
        params: Params | None = None,
    ) -> RewardType:
        """Computes the reward for a given transition between `state`, `action` to `next_state`.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def terminal(
        self, state: StateType, rng: Any, params: Params | None = None
    ) -> TerminalType:
        """Returns if the state is a final terminal state.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def state_info(self, state: StateType, params: Params | None = None) -> dict:
        """Info dict about a single state.

        The base implementation returns an empty dict.
        """
        return {}

    def transition_info(
        self,
        state: StateType,
        action: ActType,
        next_state: StateType,
        params: Params | None = None,
    ) -> dict:
        """Info dict about a full transition.

        The base implementation returns an empty dict.
        """
        return {}

    def transform(self, func: Callable[[Callable], Callable]):
        """Functional transformations.

        Replaces each core method on this instance (`initial`, `transition`,
        `observation`, `reward`, `terminal`, `state_info`, `transition_info`)
        with ``func(method)``. The render and parameter helpers are left as is.

        Args:
            func: A callable that takes a function and returns the function
                that should be used in its place.
        """
        self.initial = func(self.initial)
        self.transition = func(self.transition)
        self.observation = func(self.observation)
        self.reward = func(self.reward)
        self.terminal = func(self.terminal)
        self.state_info = func(self.state_info)

        # Rebind `transition_info` itself so callers of that method get the
        # transformed version, consistent with every other method above.
        transformed_transition_info = func(self.transition_info)
        self.transition_info = transformed_transition_info
        # `step_info` is the attribute this method used to set; keep it as an
        # alias of the same transformed callable for backward compatibility.
        self.step_info = transformed_transition_info

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
        params: Params | None = None,
    ) -> tuple[RenderStateType, np.ndarray]:
        """Show the state.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def render_init(self, params: Params | None = None, **kwargs) -> RenderStateType:
        """Initialize the render state.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType, params: Params | None = None):
        """Close the render state.

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def get_default_params(self, **kwargs) -> Params | None:
        """Get the default params.

        The base implementation returns ``None``; `__init__` stores the result
        in `self.default_params`.
        """
        return None