Module pearl.safety_modules.identity_safety_module

Expand source code
from typing import Optional

from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.safety_module import SafetyModule


class IdentitySafetyModule(SafetyModule):
    """
    A pass-through safety module: the action space it is handed is returned
    untouched, and both learning hooks are no-ops.
    """

    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        """Return ``action_space`` as received; ``subjective_state`` is not read."""
        unrestricted = action_space
        return unrestricted

    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        """Do nothing; neither ``replay_buffer`` nor ``policy_learner`` is read."""
        return None

    def learn_batch(self, batch: TransitionBatch) -> None:
        """Do nothing; ``batch`` is not read."""
        return None

Classes

class IdentitySafetyModule

A safety module that does not restrict action spaces.

Expand source code
class IdentitySafetyModule(SafetyModule):
    """
    A pass-through safety module: the action space it is handed is returned
    untouched, and both learning hooks are no-ops.
    """

    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        """Return ``action_space`` as received; ``subjective_state`` is not read."""
        unrestricted = action_space
        return unrestricted

    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        """Do nothing; neither ``replay_buffer`` nor ``policy_learner`` is read."""
        return None

    def learn_batch(self, batch: TransitionBatch) -> None:
        """Do nothing; ``batch`` is not read."""
        return None

Ancestors

Methods

def filter_action(self, subjective_state: torch.Tensor, action_space: ActionSpace) ‑> ActionSpace
Expand source code
def filter_action(
    self, subjective_state: SubjectiveState, action_space: ActionSpace
) -> ActionSpace:
    """Return ``action_space`` as received; ``subjective_state`` is not read."""
    unrestricted = action_space
    return unrestricted
def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) ‑> None
Expand source code
def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
    """Do nothing; neither ``replay_buffer`` nor ``policy_learner`` is read."""
    return None
def learn_batch(self, batch: TransitionBatch) ‑> None
Expand source code
def learn_batch(self, batch: TransitionBatch) -> None:
    """Do nothing; ``batch`` is not read."""
    return None