Module pearl.utils.instantiations.spaces.discrete
Expand source code
from __future__ import annotations
import logging
from typing import Iterator, List, Optional
import torch
from pearl.api.space import Space
from torch import Tensor
try:
import gymnasium as gym
from gymnasium.spaces import Discrete
logging.info("Using 'gymnasium' package.")
except ModuleNotFoundError:
import gym
from gym.spaces import Discrete
logging.warning("Using deprecated 'gym' package.")
class DiscreteSpace(Space):
"""A discrete space containing finitely many elements.
This class makes use of the `Discrete` space from Gymnasium, but uses an
arbitrary list of Tensor objects instead of a range of integers.
`DiscreteSpace` is also based on PyTorch tensors instead of NumPy arrays.
"""
def __init__(self, elements: List[Tensor], seed: Optional[int] = None) -> None:
"""Contructs a `DiscreteSpace`.
Args:
elements: A list of Tensors representing the elements of the space.
seed: Random seed used to initialize the random number generator of the
underlying Gym `Discrete` space.
"""
if len(elements) == 0:
raise ValueError("`DiscreteSpace` requires at least one element.")
self._set_validated_elements(elements=elements) # sets self.elements
self._gym_space = Discrete(n=len(elements), seed=seed, start=0)
def _set_validated_elements(self, elements: List[Tensor]) -> None:
"""Creates the set of elements after validating that they all have the
same shape."""
# Use the first shape to determine the expected shape.
validated = []
expected_shape = elements[0].shape
for e in elements:
if e.shape != expected_shape:
raise ValueError(
f"All elements must have the same shape. Expected {expected_shape}, "
f"but got {e.shape}."
)
validated.append(e)
self.elements = validated
@property
def n(self) -> int:
"""Returns the number of elements in this space."""
return self._gym_space.n
@property
def is_continuous(self) -> bool:
"""Checks whether this is a continuous space."""
return False
@property
def shape(self) -> torch.Size:
"""Returns the shape of an element of the space."""
return self.elements[0].shape
def sample(self, mask: Optional[Tensor] = None) -> Tensor:
"""Sample an element uniformly at random from this space.
Args:
mask: An optional Tensor of shape `n` specifying the set of available
elements, where `1` represents valid elements and `0` invalid elements.
This mask is passed to Gymnasium's `Discrete.sample` method. If no
elements are available, `self.elements[0]` is returned.
Returns:
A randomly sampled (available) element.
"""
mask_np = mask.numpy().astype(int) if mask is not None else None
idx = self._gym_space.sample(mask=mask_np)
return self.elements[idx]
def __iter__(self) -> Iterator[Tensor]:
for e in self.elements:
yield e
def __getitem__(self, index: int) -> Tensor:
return self.elements[index]
@staticmethod
def from_gym(gym_space: gym.Space) -> DiscreteSpace:
"""Constructs a `DiscreteSpace` given a Gymnasium `Discrete` space.
Convert from Gymnasium's index set {start, start + n - 1} to a list
of tensors:
[torch.tensor([start]), ..., torch.tensor([start + n - 1])],
in accordance to what is expected by `DiscreteSpace`.
Args:
gym_space: A Gymnasium `Discrete` space.
Returns:
A `DiscreteSpace` with the same number of elements as `gym_space`.
"""
assert isinstance(gym_space, Discrete)
start, n = gym_space.start, gym_space.n
return DiscreteSpace(
elements=list(torch.arange(start=start, end=start + n).view(-1, 1)),
seed=gym_space._np_random,
)
Classes
class DiscreteSpace (elements: List[Tensor], seed: Optional[int] = None)
-
A discrete space containing finitely many elements.
This class makes use of the
Discrete
space from Gymnasium, but uses an arbitrary list of Tensor objects instead of a range of integers.DiscreteSpace
is also based on PyTorch tensors instead of NumPy arrays.Contructs a
DiscreteSpace
.Args
elements
- A list of Tensors representing the elements of the space.
seed
- Random seed used to initialize the random number generator of the
underlying Gym
Discrete
space.
Expand source code
class DiscreteSpace(Space): """A discrete space containing finitely many elements. This class makes use of the `Discrete` space from Gymnasium, but uses an arbitrary list of Tensor objects instead of a range of integers. `DiscreteSpace` is also based on PyTorch tensors instead of NumPy arrays. """ def __init__(self, elements: List[Tensor], seed: Optional[int] = None) -> None: """Contructs a `DiscreteSpace`. Args: elements: A list of Tensors representing the elements of the space. seed: Random seed used to initialize the random number generator of the underlying Gym `Discrete` space. """ if len(elements) == 0: raise ValueError("`DiscreteSpace` requires at least one element.") self._set_validated_elements(elements=elements) # sets self.elements self._gym_space = Discrete(n=len(elements), seed=seed, start=0) def _set_validated_elements(self, elements: List[Tensor]) -> None: """Creates the set of elements after validating that they all have the same shape.""" # Use the first shape to determine the expected shape. validated = [] expected_shape = elements[0].shape for e in elements: if e.shape != expected_shape: raise ValueError( f"All elements must have the same shape. Expected {expected_shape}, " f"but got {e.shape}." ) validated.append(e) self.elements = validated @property def n(self) -> int: """Returns the number of elements in this space.""" return self._gym_space.n @property def is_continuous(self) -> bool: """Checks whether this is a continuous space.""" return False @property def shape(self) -> torch.Size: """Returns the shape of an element of the space.""" return self.elements[0].shape def sample(self, mask: Optional[Tensor] = None) -> Tensor: """Sample an element uniformly at random from this space. Args: mask: An optional Tensor of shape `n` specifying the set of available elements, where `1` represents valid elements and `0` invalid elements. This mask is passed to Gymnasium's `Discrete.sample` method. If no elements are available, `self.elements[0]` is returned. Returns: A randomly sampled (available) element. """ mask_np = mask.numpy().astype(int) if mask is not None else None idx = self._gym_space.sample(mask=mask_np) return self.elements[idx] def __iter__(self) -> Iterator[Tensor]: for e in self.elements: yield e def __getitem__(self, index: int) -> Tensor: return self.elements[index] @staticmethod def from_gym(gym_space: gym.Space) -> DiscreteSpace: """Constructs a `DiscreteSpace` given a Gymnasium `Discrete` space. Convert from Gymnasium's index set {start, start + n - 1} to a list of tensors: [torch.tensor([start]), ..., torch.tensor([start + n - 1])], in accordance to what is expected by `DiscreteSpace`. Args: gym_space: A Gymnasium `Discrete` space. Returns: A `DiscreteSpace` with the same number of elements as `gym_space`. """ assert isinstance(gym_space, Discrete) start, n = gym_space.start, gym_space.n return DiscreteSpace( elements=list(torch.arange(start=start, end=start + n).view(-1, 1)), seed=gym_space._np_random, )
Ancestors
- Space
- abc.ABC
Subclasses
Static methods
def from_gym(gym_space: gym.Space) ‑> DiscreteSpace
-
Constructs a
DiscreteSpace
given a GymnasiumDiscrete
space. Convert from Gymnasium's index set {start, start + n - 1} to a list of tensors: [torch.tensor([start]), …, torch.tensor([start + n - 1])], in accordance to what is expected byDiscreteSpace
.Args
gym_space
- A Gymnasium
Discrete
space.
Returns
A
DiscreteSpace
with the same number of elements asgym_space
.Expand source code
@staticmethod def from_gym(gym_space: gym.Space) -> DiscreteSpace: """Constructs a `DiscreteSpace` given a Gymnasium `Discrete` space. Convert from Gymnasium's index set {start, start + n - 1} to a list of tensors: [torch.tensor([start]), ..., torch.tensor([start + n - 1])], in accordance to what is expected by `DiscreteSpace`. Args: gym_space: A Gymnasium `Discrete` space. Returns: A `DiscreteSpace` with the same number of elements as `gym_space`. """ assert isinstance(gym_space, Discrete) start, n = gym_space.start, gym_space.n return DiscreteSpace( elements=list(torch.arange(start=start, end=start + n).view(-1, 1)), seed=gym_space._np_random, )
Instance variables
var n : int
-
Returns the number of elements in this space.
Expand source code
@property def n(self) -> int: """Returns the number of elements in this space.""" return self._gym_space.n
Methods
def sample(self, mask: Optional[Tensor] = None) ‑> torch.Tensor
-
Sample an element uniformly at random from this space.
Args
mask
- An optional Tensor of shape
n
specifying the set of available elements, where1
represents valid elements and0
invalid elements. This mask is passed to Gymnasium'sDiscrete.sample
method. If no elements are available,self.elements[0]
is returned.
Returns
A randomly sampled (available) element.
Expand source code
def sample(self, mask: Optional[Tensor] = None) -> Tensor: """Sample an element uniformly at random from this space. Args: mask: An optional Tensor of shape `n` specifying the set of available elements, where `1` represents valid elements and `0` invalid elements. This mask is passed to Gymnasium's `Discrete.sample` method. If no elements are available, `self.elements[0]` is returned. Returns: A randomly sampled (available) element. """ mask_np = mask.numpy().astype(int) if mask is not None else None idx = self._gym_space.sample(mask=mask_np) return self.elements[idx]
Inherited members