Module pearl.utils.instantiations.environments.gym_environment

Expand source code
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
from pearl.api.action import Action
from pearl.api.action_result import ActionResult
from pearl.api.action_space import ActionSpace
from pearl.api.environment import Environment
from pearl.api.observation import Observation
from pearl.api.space import Space
from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import Tensor

try:
    import gymnasium as gym

    logging.info("Using 'gymnasium' package.")
except ModuleNotFoundError:
    import gym

    logging.warning("Using deprecated 'gym' package.")


def single_element_tensor_to_int(x: Tensor) -> int:
    """Converts the tensor `x` to a Python `int` via the built-in `int()`.

    Used as the Pearl-to-Gym action transform for "Discrete" spaces.
    """
    as_int = int(x)
    return as_int


def tensor_to_numpy(x: Tensor) -> np.ndarray:
    """Converts the tensor `x` to a NumPy array.

    Used as the Pearl-to-Gym action transform for "Box" spaces.
    """
    # force=True asks PyTorch to perform whatever conversion is needed
    # to produce an array, instead of raising.
    as_array = x.numpy(force=True)
    return as_array


# Dispatch tables keyed by the class name of a Gym space (e.g. "Discrete").
# A Gym space whose class name is missing from a table is reported as
# unsupported by the helper functions that consult it.

# Gym action space class name -> Pearl action space class.
GYM_TO_PEARL_ACTION_SPACE = dict(
    Discrete=DiscreteActionSpace,
    Box=BoxActionSpace,
    # Add more here as needed
)
# Gym observation space class name -> Pearl space class.
GYM_TO_PEARL_OBSERVATION_SPACE = dict(
    Discrete=DiscreteSpace,
    Box=BoxSpace,
    # Add more here as needed
)
# Gym action space class name -> function converting a Pearl action into
# the value handed to the Gym environment.
PEARL_TO_GYM_ACTION = dict(
    Discrete=single_element_tensor_to_int,
    Box=tensor_to_numpy,
    # Add more here as needed
)


class GymEnvironment(Environment):
    """A wrapper for `gym.Env` to behave like Pearl's `Environment`.

    The wrapped environment is stored in `self.env`. Its Gym action and
    observation spaces are converted to Pearl spaces once, at construction.
    """

    def __init__(
        self, env_or_env_name: Union[gym.Env, str], *args: Any, **kwargs: Any
    ) -> None:
        """Constructs a `GymEnvironment` wrapper.

        Args:
            env_or_env_name: A gym.Env instance or a name of a gym.Env.
            args: Arguments passed to `gym.make()` if the first argument is a string.
            kwargs: Keyword arguments passed to `gym.make()` if the first argument is a string.

        Raises:
            NotImplementedError: If the class name of the environment's action
                or observation space is not registered in the conversion tables.
        """
        # isinstance (not an exact type comparison) so that str subclasses
        # are also treated as environment names.
        if isinstance(env_or_env_name, str):
            env = gym.make(env_or_env_name, *args, **kwargs)
        else:
            env = env_or_env_name
        self.env: gym.Env = env
        self._action_space: ActionSpace = _get_pearl_space(
            gym_space=self.env.action_space,
            gym_to_pearl_map=GYM_TO_PEARL_ACTION_SPACE,
        )
        self._observation_space: Space = _get_pearl_space(
            gym_space=self.env.observation_space,
            gym_to_pearl_map=GYM_TO_PEARL_OBSERVATION_SPACE,
        )

    @property
    def action_space(self) -> ActionSpace:
        """Returns the Pearl action space for this environment."""
        return self._action_space

    @property
    def observation_space(self) -> Space:
        """Returns the Pearl observation space for this environment."""
        return self._observation_space

    def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
        """Resets the environment and returns the initial observation and
        initial action space.

        Args:
            seed: Optional seed. It seeds both the Pearl-side and the Gym-side
                action spaces and is forwarded to the wrapped `reset()`.

        Returns:
            A tuple `(observation, action_space)`. An observation that is a
            NumPy array is cast to float32.
        """
        # pyre-fixme: ActionSpace does not have _gym_space
        # FIXME: private attribute _gym_space should not be accessed
        self._action_space._gym_space.seed(seed)
        self.env.action_space.seed(seed)
        # Reset exactly once, with the seed. A separate unseeded "probe" reset
        # would do the work twice and throw the first episode start away.
        reset_result = self.env.reset(seed=seed)
        if (
            isinstance(reset_result, (tuple, list))
            and len(reset_result) == 2
            and isinstance(reset_result[1], dict)
        ):
            # newer Gym versions return an info dict.
            observation = reset_result[0]
        elif isinstance(reset_result, dict):
            # TODO: Deprecate this part at some point and only support new
            # version of Gymnasium?
            observation = list(reset_result.values())[0]  # pyre-ignore
        else:
            # Older Gym versions return the bare observation.
            observation = reset_result
        if isinstance(observation, np.ndarray):
            observation = observation.astype(np.float32)
        return observation, self.action_space

    def step(self, action: Action) -> ActionResult:
        """Takes one step in the environment given the agent's action. Returns an
        `ActionResult` object containing the next observation, reward, and done flag.

        Args:
            action: The Pearl action; it is converted to the Gym format
                according to the class of the wrapped action space.

        Raises:
            ValueError: If the wrapped `step()` returns neither 4 nor 5 elements.
            NotImplementedError: If the wrapped action space is not supported.
        """
        # Convert action to the format expected by Gymnasium
        effective_action = _get_gym_action(
            pearl_action=action, gym_space=self.env.action_space
        )
        # Take a step in the environment and receive an action result
        gym_action_result = self.env.step(effective_action)
        if len(gym_action_result) == 4:
            # Older Gym versions use 'done' as opposed to 'terminated' and 'truncated'
            observation, reward, done, info = gym_action_result  # pyre-ignore
            terminated = done
            truncated = False
        elif len(gym_action_result) == 5:
            # Newer Gym versions use 'terminated' and 'truncated'
            observation, reward, terminated, truncated, info = gym_action_result
        else:
            raise ValueError(
                f"Unexpected action result from Gym (expected 4 or 5 elements): {gym_action_result}"
            )
        # Optional extras carried in the info dict; None when absent.
        cost = info.get("cost")
        available_action_space = info.get("available_action_space")

        # Narrow NumPy double-precision values to float32.
        if isinstance(observation, np.ndarray):
            observation = observation.astype(np.float32)
        if isinstance(reward, np.float64):
            reward = reward.astype(np.float32)
        if isinstance(cost, np.float64):
            cost = cost.astype(np.float32)

        return ActionResult(
            observation=observation,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            cost=cost,
            available_action_space=available_action_space,
        )

    def render(self) -> None:
        """Delegates rendering to the wrapped environment."""
        self.env.render()

    def close(self) -> None:
        """Closes the wrapped environment."""
        self.env.close()

    def __str__(self) -> str:
        """Returns the id of the wrapped environment's spec, or the class name
        of the wrapped environment when it has no spec."""
        spec = getattr(self.env, "spec", None)
        if spec is None:
            # Environments not created through a registry may carry no spec.
            return type(self.env).__name__
        return spec.id


def _get_gym_action(
    pearl_action: Action, gym_space: gym.Space
) -> Union[int, np.ndarray]:
    """A helper function to convert a Pearl `Action` to an action compatible with
    the Gym action space `gym_space`.

    The conversion function is looked up in `PEARL_TO_GYM_ACTION` by the class
    name of `gym_space`.

    Args:
        pearl_action: The action produced on the Pearl side.
        gym_space: The Gym action space the action is destined for.

    Returns:
        The result of applying the registered conversion to `pearl_action`.

    Raises:
        NotImplementedError: If no conversion is registered for the class of
            `gym_space`.
    """
    gym_space_name = gym_space.__class__.__name__
    try:
        pearl_to_gym_action_transform = PEARL_TO_GYM_ACTION[gym_space_name]
    except KeyError:
        # The KeyError is an implementation detail of the table lookup;
        # suppress it so callers see only the meaningful error.
        raise NotImplementedError(
            f"The Gym space '{gym_space_name}' is not yet supported in Pearl."
        ) from None
    return pearl_to_gym_action_transform(pearl_action)


def _get_pearl_space(
    gym_space: gym.Space, gym_to_pearl_map: Dict[str, Any]
) -> ActionSpace:
    """Returns the Pearl action space for this environment."""
    gym_space_name = gym_space.__class__.__name__
    try:
        pearl_action_space_cls = gym_to_pearl_map[gym_space_name]
    except KeyError:
        raise NotImplementedError(
            f"The Gym space '{gym_space_name}' is not yet supported in Pearl."
        )
    return pearl_action_space_cls.from_gym(gym_space)

Functions

def single_element_tensor_to_int(x: torch.Tensor) ‑> int
Expand source code
def single_element_tensor_to_int(x: Tensor) -> int:
    """Converts the tensor `x` to a Python `int` via the built-in `int()`."""
    as_int = int(x)
    return as_int
def tensor_to_numpy(x: torch.Tensor) ‑> numpy.ndarray
Expand source code
def tensor_to_numpy(x: Tensor) -> np.ndarray:
    """Converts the tensor `x` to a NumPy array."""
    # force=True asks PyTorch to perform whatever conversion is needed
    # to produce an array, instead of raising.
    as_array = x.numpy(force=True)
    return as_array

Classes

class GymEnvironment (env_or_env_name: Union[gymnasium.core.Env, str], *args: Any, **kwargs: Any)

A wrapper for gym.Env to behave like Pearl's Environment.

Constructs a GymEnvironment wrapper.

Args

env_or_env_name
A gym.Env instance or a name of a gym.Env.
args
Arguments passed to gym.make() if the first argument is a string.
kwargs
Keyword arguments passed to gym.make() if the first argument is a string.
Expand source code
class GymEnvironment(Environment):
    """A wrapper for `gym.Env` to behave like Pearl's `Environment`.

    The wrapped environment is stored in `self.env`. Its Gym action and
    observation spaces are converted to Pearl spaces once, at construction.
    """

    def __init__(
        self, env_or_env_name: Union[gym.Env, str], *args: Any, **kwargs: Any
    ) -> None:
        """Constructs a `GymEnvironment` wrapper.

        Args:
            env_or_env_name: A gym.Env instance or a name of a gym.Env.
            args: Arguments passed to `gym.make()` if the first argument is a string.
            kwargs: Keyword arguments passed to `gym.make()` if the first argument is a string.

        Raises:
            NotImplementedError: If the class name of the environment's action
                or observation space is not registered in the conversion tables.
        """
        # isinstance (not an exact type comparison) so that str subclasses
        # are also treated as environment names.
        if isinstance(env_or_env_name, str):
            env = gym.make(env_or_env_name, *args, **kwargs)
        else:
            env = env_or_env_name
        self.env: gym.Env = env
        self._action_space: ActionSpace = _get_pearl_space(
            gym_space=self.env.action_space,
            gym_to_pearl_map=GYM_TO_PEARL_ACTION_SPACE,
        )
        self._observation_space: Space = _get_pearl_space(
            gym_space=self.env.observation_space,
            gym_to_pearl_map=GYM_TO_PEARL_OBSERVATION_SPACE,
        )

    @property
    def action_space(self) -> ActionSpace:
        """Returns the Pearl action space for this environment."""
        return self._action_space

    @property
    def observation_space(self) -> Space:
        """Returns the Pearl observation space for this environment."""
        return self._observation_space

    def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
        """Resets the environment and returns the initial observation and
        initial action space.

        Args:
            seed: Optional seed. It seeds both the Pearl-side and the Gym-side
                action spaces and is forwarded to the wrapped `reset()`.

        Returns:
            A tuple `(observation, action_space)`. An observation that is a
            NumPy array is cast to float32.
        """
        # pyre-fixme: ActionSpace does not have _gym_space
        # FIXME: private attribute _gym_space should not be accessed
        self._action_space._gym_space.seed(seed)
        self.env.action_space.seed(seed)
        # Reset exactly once, with the seed. A separate unseeded "probe" reset
        # would do the work twice and throw the first episode start away.
        reset_result = self.env.reset(seed=seed)
        if (
            isinstance(reset_result, (tuple, list))
            and len(reset_result) == 2
            and isinstance(reset_result[1], dict)
        ):
            # newer Gym versions return an info dict.
            observation = reset_result[0]
        elif isinstance(reset_result, dict):
            # TODO: Deprecate this part at some point and only support new
            # version of Gymnasium?
            observation = list(reset_result.values())[0]  # pyre-ignore
        else:
            # Older Gym versions return the bare observation.
            observation = reset_result
        if isinstance(observation, np.ndarray):
            observation = observation.astype(np.float32)
        return observation, self.action_space

    def step(self, action: Action) -> ActionResult:
        """Takes one step in the environment given the agent's action. Returns an
        `ActionResult` object containing the next observation, reward, and done flag.

        Args:
            action: The Pearl action; it is converted to the Gym format
                according to the class of the wrapped action space.

        Raises:
            ValueError: If the wrapped `step()` returns neither 4 nor 5 elements.
            NotImplementedError: If the wrapped action space is not supported.
        """
        # Convert action to the format expected by Gymnasium
        effective_action = _get_gym_action(
            pearl_action=action, gym_space=self.env.action_space
        )
        # Take a step in the environment and receive an action result
        gym_action_result = self.env.step(effective_action)
        if len(gym_action_result) == 4:
            # Older Gym versions use 'done' as opposed to 'terminated' and 'truncated'
            observation, reward, done, info = gym_action_result  # pyre-ignore
            terminated = done
            truncated = False
        elif len(gym_action_result) == 5:
            # Newer Gym versions use 'terminated' and 'truncated'
            observation, reward, terminated, truncated, info = gym_action_result
        else:
            raise ValueError(
                f"Unexpected action result from Gym (expected 4 or 5 elements): {gym_action_result}"
            )
        # Optional extras carried in the info dict; None when absent.
        cost = info.get("cost")
        available_action_space = info.get("available_action_space")

        # Narrow NumPy double-precision values to float32.
        if isinstance(observation, np.ndarray):
            observation = observation.astype(np.float32)
        if isinstance(reward, np.float64):
            reward = reward.astype(np.float32)
        if isinstance(cost, np.float64):
            cost = cost.astype(np.float32)

        return ActionResult(
            observation=observation,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            cost=cost,
            available_action_space=available_action_space,
        )

    def render(self) -> None:
        """Delegates rendering to the wrapped environment."""
        self.env.render()

    def close(self) -> None:
        """Closes the wrapped environment."""
        self.env.close()

    def __str__(self) -> str:
        """Returns the id of the wrapped environment's spec, or the class name
        of the wrapped environment when it has no spec."""
        spec = getattr(self.env, "spec", None)
        if spec is None:
            # Environments not created through a registry may carry no spec.
            return type(self.env).__name__
        return spec.id

Ancestors

Instance variables

var action_space : ActionSpace

Returns the Pearl action space for this environment.

Expand source code
@property
def action_space(self) -> ActionSpace:
    """The Pearl action space stored on this environment wrapper."""
    pearl_space = self._action_space
    return pearl_space
var observation_space : Space
Expand source code
@property
def observation_space(self) -> Space:
    """The Pearl observation space stored on this environment wrapper."""
    pearl_space = self._observation_space
    return pearl_space

Methods

def reset(self, seed: Optional[int] = None) ‑> Tuple[object, ActionSpace]

Resets the environment and returns the initial observation and initial action space.

Expand source code
def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
    """Resets the environment and returns the initial observation and
    initial action space.

    Args:
        seed: Optional seed. It seeds both the Pearl-side and the Gym-side
            action spaces and is forwarded to the wrapped `reset()`.

    Returns:
        A tuple `(observation, action_space)`. An observation that is a
        NumPy array is cast to float32.
    """
    # pyre-fixme: ActionSpace does not have _gym_space
    # FIXME: private attribute _gym_space should not be accessed
    self._action_space._gym_space.seed(seed)
    self.env.action_space.seed(seed)
    # Reset exactly once, with the seed. A separate unseeded "probe" reset
    # would do the work twice and throw the first episode start away.
    reset_result = self.env.reset(seed=seed)
    if (
        isinstance(reset_result, (tuple, list))
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        # newer Gym versions return an info dict.
        observation = reset_result[0]
    elif isinstance(reset_result, dict):
        # TODO: Deprecate this part at some point and only support new
        # version of Gymnasium?
        observation = list(reset_result.values())[0]  # pyre-ignore
    else:
        # Older Gym versions return the bare observation.
        observation = reset_result
    if isinstance(observation, np.ndarray):
        observation = observation.astype(np.float32)
    return observation, self.action_space

Inherited members