Module pearl.policy_learners.policy_learner

Source code
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypeVar

import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.action_representation_modules.identity_action_representation_module import (
    IdentityActionRepresentationModule,
)
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    HistorySummarizationModule,
    SubjectiveState,
)
from pearl.history_summarization_modules.identity_history_summarization_module import (
    IdentityHistorySummarizationModule,
)
from pearl.policy_learners.exploration_modules.common.no_exploration import (
    NoExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.device import is_distribution_enabled
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class PolicyLearner(torch.nn.Module, ABC):
    """
    An abstract interface for policy learners.

    Important requirements for policy learners using tensors:
        1. Attribute `requires_tensors` must be `True` (this is the default).
        2. If a policy learner is to operate on a given torch device,
           the policy learner must be moved to that device using method `to(device)`.
        3. All inputs to policy learners must be moved to the proper device,
           including `TransitionBatch`es (which also have a `to(device)` method).
    """

    # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use  # noqa E501
    # of `T` to annotate `self`. At least one method of `PolicyLearner`
    # returns `self` and we want those return values to be
    # the type of the subclass, not the looser type of `PolicyLearner`.
    T = TypeVar("T", bound="PolicyLearner")

    def __init__(
        self,
        on_policy: bool,
        is_action_continuous: bool,
        action_space: Optional[ActionSpace] = None,
        training_rounds: int = 100,
        batch_size: int = 1,
        requires_tensors: bool = True,
        action_representation_module: Optional[ActionRepresentationModule] = None,
        **options: Any,
    ) -> None:
        super(PolicyLearner, self).__init__()

        self._exploration_module: ExplorationModule = (
            options["exploration_module"]
            if "exploration_module" in options
            else NoExploration()
        )

        # User needs to either provide the action space or an action representation module at
        # policy learner's initialization for sequential decision making.
        if action_representation_module is None:
            if action_space is not None:
                # If a policy learner is initialized with an action space, then we assume that
                # the agent does not need dynamic action space support.
                self._action_representation_module: ActionRepresentationModule = (
                    IdentityActionRepresentationModule(
                        max_number_actions=action_space.n
                        if isinstance(action_space, DiscreteActionSpace)
                        else -1,
                        representation_dim=action_space.action_dim,
                    )
                )
            else:
                # This is only used in the case of bandit learning applications.
                # TODO: add action representation module for bandit learning applications.
                self._action_representation_module = (
                    IdentityActionRepresentationModule()
                )
        else:
            # User needs to at least specify action dimensions if no action_space is provided.
            assert action_representation_module.representation_dim != -1
            self._action_representation_module = action_representation_module

        self._history_summarization_module: HistorySummarizationModule = (
            IdentityHistorySummarizationModule()
        )

        self._training_rounds = training_rounds
        self._batch_size = batch_size
        self._training_steps = 0
        self.on_policy = on_policy
        self.is_action_continuous = is_action_continuous
        self.distribution_enabled: bool = is_distribution_enabled()
        self.requires_tensors = requires_tensors

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def exploration_module(self) -> ExplorationModule:
        return self._exploration_module

    @property
    def action_representation_module(self) -> ActionRepresentationModule:
        return self._action_representation_module

    @exploration_module.setter
    def exploration_module(self, new_exploration_module: ExplorationModule) -> None:
        self._exploration_module = new_exploration_module

    def get_action_representation_module(self) -> ActionRepresentationModule:
        return self._action_representation_module

    def set_history_summarization_module(
        self, value: HistorySummarizationModule
    ) -> None:
        self._history_summarization_module = value

    def reset(self, action_space: ActionSpace) -> None:
        """Resets policy maker for a new episode. Default implementation does nothing."""
        pass

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        exploit: bool = False,
    ) -> Action:
        pass

    def learn(
        self,
        replay_buffer: ReplayBuffer,
    ) -> Dict[str, Any]:
        """
        Args:
            replay_buffer: the replay buffer from which training batches are sampled

        Returns:
            A dictionary mapping each metric name to a list of values, one per training round
        """
        batch_size = self._batch_size if not self.on_policy else len(replay_buffer)

        if len(replay_buffer) < batch_size or len(replay_buffer) == 0:
            return {}

        report = {}
        for _ in range(self._training_rounds):
            self._training_steps += 1
            batch = replay_buffer.sample(batch_size)
            single_report = {}
            if isinstance(batch, TransitionBatch):
                batch = self.preprocess_batch(batch)
                single_report = self.learn_batch(batch)

            for k, v in single_report.items():
                if k in report:
                    report[k].append(v)
                else:
                    report[k] = [v]
        return report

    def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
        """
        Processes a batch of transitions before passing it to learn_batch().
        This function can be used to implement preprocessing steps such as
        transforming the actions.
        """
        batch.state = self._history_summarization_module(batch.state)
        with torch.no_grad():
            batch.next_state = self._history_summarization_module(batch.next_state)

        batch.action = self._action_representation_module(batch.action)
        if batch.next_action is not None:
            batch.next_action = self._action_representation_module(batch.next_action)
        if batch.curr_available_actions is not None:
            batch.curr_available_actions = self._action_representation_module(
                batch.curr_available_actions
            )
        if batch.next_available_actions is not None:
            batch.next_available_actions = self._action_representation_module(
                batch.next_available_actions
            )
        return batch

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Args:
            batch: the batch of transitions the agent learns from

        Returns:
            A dictionary which includes useful metrics
        """
        raise NotImplementedError("learn_batch is not implemented")

    def __str__(self) -> str:
        return self.__class__.__name__


class DistributionalPolicyLearner(PolicyLearner):
    """
    An abstract interface for distributional policy learners.
    Enforces the property of a risk-sensitive safety module.
    """

    def __init__(
        self,
        on_policy: bool,
        is_action_continuous: bool,
        training_rounds: int = 100,
        batch_size: int = 1,
        action_representation_module: Optional[ActionRepresentationModule] = None,
        **options: Any,
    ) -> None:

        super(DistributionalPolicyLearner, self).__init__(
            on_policy=on_policy,
            is_action_continuous=is_action_continuous,
            training_rounds=training_rounds,
            batch_size=batch_size,
            action_representation_module=action_representation_module,
            **options,
        )

Classes

class DistributionalPolicyLearner (on_policy: bool, is_action_continuous: bool, training_rounds: int = 100, batch_size: int = 1, action_representation_module: Optional[ActionRepresentationModule] = None, **options: Any)

An abstract interface for distributional policy learners. Enforces the property of a risk-sensitive safety module.

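The abstract act and learn_batch methods must still be implemented by concrete subclasses. A minimal, hypothetical sketch follows; QuantileQLearner and its hyperparameter choices are illustrative, not part of Pearl:

from typing import Any, Dict

from pearl.policy_learners.policy_learner import DistributionalPolicyLearner
from pearl.replay_buffers.transition import TransitionBatch


class QuantileQLearner(DistributionalPolicyLearner):
    # Hypothetical subclass; hyperparameter values are illustrative.
    def __init__(self, **options: Any) -> None:
        super().__init__(
            on_policy=False,
            is_action_continuous=False,
            training_rounds=10,
            batch_size=32,
            **options,  # e.g. exploration_module=... is consumed by PolicyLearner
        )

    def act(self, subjective_state, available_action_space, exploit=False):
        ...  # sample an action from the learned return distribution

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        ...  # update the distributional critic
        return {"loss": 0.0}
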


Ancestors

  • PolicyLearner
  • torch.nn.modules.module.Module
  • abc.ABC

class PolicyLearner (on_policy: bool, is_action_continuous: bool, action_space: Optional[ActionSpace] = None, training_rounds: int = 100, batch_size: int = 1, requires_tensors: bool = True, action_representation_module: Optional[ActionRepresentationModule] = None, **options: Any)

An abstract interface for policy learners.

Important requirements for policy learners using tensors:

  1. Attribute requires_tensors must be True (this is the default).
  2. If a policy learner is to operate on a given torch device, it must be moved to that device using the to(device) method.
  3. All inputs to policy learners must be moved to the proper device, including TransitionBatches (which also have a to(device) method).

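A minimal sketch of the device discipline described above, assuming my_learner is an instance of some concrete PolicyLearner subclass and batch is a TransitionBatch obtained from a replay buffer:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PolicyLearner subclasses torch.nn.Module, so it moves like any module.
my_learner.to(device)

# TransitionBatch exposes its own to(device); inputs must be on the same
# device as the learner before learning.
batch = batch.to(device)
metrics = my_learner.learn_batch(my_learner.preprocess_batch(batch))
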


Ancestors

  • torch.nn.modules.module.Module
  • abc.ABC

Class variables

var T

A TypeVar bound to PolicyLearner, used so that methods returning self are typed as the subclass rather than as PolicyLearner.

Instance variables

var action_representation_module : ActionRepresentationModule

Read-only property returning the learner's action representation module.

var batch_size : int

Read-only property returning the training batch size.

var exploration_module : ExplorationModule

Property returning the learner's exploration module; a corresponding setter allows it to be replaced.

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
Abstract method. Returns an action for the given subjective state and available action space. When exploit is True, the learner is expected to act greedily, without exploration.
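A usage sketch, assuming learner, observation, and action_space are defined elsewhere:

# Exploratory action (default): the exploration module may be consulted.
action = learner.act(observation, action_space)

# Greedy action for evaluation: exploration is bypassed.
action = learner.act(observation, action_space, exploit=True)
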
def get_action_representation_module(self) ‑> ActionRepresentationModule
Returns the action representation module; equivalent to the action_representation_module property.
def learn(self, replay_buffer: ReplayBuffer) ‑> Dict[str, Any]

Args

replay_buffer
the replay buffer from which training batches are sampled

Returns

A dictionary mapping each metric name to a list of values, one per training round

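A usage sketch, assuming learner and a populated replay_buffer exist. Because each round's metrics are appended, the report maps metric names to lists with one entry per training round:

report = learner.learn(replay_buffer)
# e.g. report == {"loss": [0.91, 0.85, ...]}; an empty dict means the
# buffer did not yet contain enough transitions to sample a batch.
for name, values in report.items():
    print(f"{name}: last={values[-1]:.4f} over {len(values)} rounds")
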
def learn_batch(self, batch: TransitionBatch) ‑> Dict[str, Any]

Args

batch
the batch of transitions the agent learns from

Returns

A dictionary which includes useful metrics

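A sketch of a typical override, assuming the subclass created self._network and self._optimizer in its __init__ (not shown); the regression-style loss is illustrative, not Pearl's actual update rule:

import torch
from typing import Any, Dict

from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch


class MyValueLearner(PolicyLearner):  # hypothetical; __init__ and act omitted
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # Assumes self._network and self._optimizer exist (created in __init__).
        values = self._network(batch.state, batch.action).squeeze(-1)
        loss = torch.nn.functional.mse_loss(values, batch.reward)
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # Returned metrics are aggregated per-round by learn().
        return {"loss": loss.item()}
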
def preprocess_batch(self, batch: TransitionBatch) ‑> TransitionBatch

Processes a batch of transitions before passing it to learn_batch(). This function can be used to implement preprocessing steps such as transforming the actions.

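Subclasses usually extend rather than replace the default processing; a sketch with a hypothetical reward-standardization step added after the standard steps:

from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch


class RewardNormalizingLearner(PolicyLearner):  # hypothetical; other methods omitted
    def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
        # Run the default history-summarization and action-representation steps.
        batch = super().preprocess_batch(batch)
        # Hypothetical extra step: standardize rewards within the batch.
        batch.reward = (batch.reward - batch.reward.mean()) / (
            batch.reward.std() + 1e-8
        )
        return batch
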
def reset(self, action_space: ActionSpace) ‑> None

Resets the policy learner for a new episode. The default implementation does nothing.

def set_history_summarization_module(self, value: HistorySummarizationModule) ‑> None
Replaces the learner's history summarization module, which preprocess_batch uses to map states to subjective states.
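A usage sketch, assuming learner is a concrete policy learner. The identity module shown is the default; a different HistorySummarizationModule (for example an LSTM-based one, if present in your Pearl version) could be substituted to handle partial observability:

from pearl.history_summarization_modules.identity_history_summarization_module import (
    IdentityHistorySummarizationModule,
)

# Restores the default behavior: subjective states are the raw states.
learner.set_history_summarization_module(IdentityHistorySummarizationModule())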