Module pearl.safety_modules.safety_module

Expand source code
from abc import ABC, abstractmethod

from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.policy_learner import PolicyLearner

from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch


class SafetyModule(ABC):
    """
    An abstract interface for safety modules.

    Every method below is abstract, so a subclass must override all three
    (`filter_action`, `learn` and `learn_batch`) before it can be instantiated.
    """

    @abstractmethod
    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        """
        Produce an action space for the given subjective state.

        Args:
            subjective_state: the subjective state the filtering is done for.
            action_space: the candidate action space to filter.

        Returns:
            An `ActionSpace`.
            NOTE(review): presumably `action_space` restricted to the actions
            the module considers safe -- confirm against concrete subclasses.
        """
        pass

    @abstractmethod
    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        """
        Update this module given a replay buffer and a policy learner.

        Args:
            replay_buffer: the replay buffer made available to the update.
            policy_learner: the policy learner made available to the update.

        NOTE(review): how the two arguments are used is defined entirely by
        subclasses; this interface only fixes the signature.
        """
        pass

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> None:
        """
        Update this module from a single `TransitionBatch`.

        Args:
            batch: the batch to learn from.
        """
        pass

Classes

class SafetyModule

An abstract interface for safety modules.

Expand source code
class SafetyModule(ABC):
    """
    An abstract interface for safety modules.

    Every method below is abstract, so a subclass must override all three
    (`filter_action`, `learn` and `learn_batch`) before it can be instantiated.
    """

    @abstractmethod
    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        """
        Produce an action space for the given subjective state.

        Args:
            subjective_state: the subjective state the filtering is done for.
            action_space: the candidate action space to filter.

        Returns:
            An `ActionSpace`.
            NOTE(review): presumably `action_space` restricted to the actions
            the module considers safe -- confirm against concrete subclasses.
        """
        pass

    @abstractmethod
    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        """
        Update this module given a replay buffer and a policy learner.

        Args:
            replay_buffer: the replay buffer made available to the update.
            policy_learner: the policy learner made available to the update.

        NOTE(review): how the two arguments are used is defined entirely by
        subclasses; this interface only fixes the signature.
        """
        pass

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> None:
        """
        Update this module from a single `TransitionBatch`.

        Args:
            batch: the batch to learn from.
        """
        pass

Ancestors

  • abc.ABC

Subclasses

Methods

def filter_action(self, subjective_state: torch.Tensor, action_space: ActionSpace) ‑> ActionSpace
Expand source code
@abstractmethod
def filter_action(
    self,
    subjective_state: SubjectiveState,
    action_space: ActionSpace,
) -> ActionSpace:
    """
    Abstract hook: produce an action space for `subjective_state`.

    The base implementation does nothing and returns None; subclasses
    supply the actual filtering of `action_space`.
    """
def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) ‑> None
Expand source code
@abstractmethod
def learn(
    self,
    replay_buffer: ReplayBuffer,
    policy_learner: PolicyLearner,
) -> None:
    """
    Abstract hook: update the module given a replay buffer and a policy learner.

    The base implementation does nothing; subclasses define how
    `replay_buffer` and `policy_learner` are used.
    """
def learn_batch(self, batch: TransitionBatch) ‑> None
Expand source code
@abstractmethod
def learn_batch(
    self,
    batch: TransitionBatch,
) -> None:
    """
    Abstract hook: update the module from a single `TransitionBatch`.

    The base implementation does nothing; subclasses define the update.
    """