Module pearl.history_summarization_modules.identity_history_summarization_module

Expand source code
from typing import Optional

import torch

from pearl.api.action import Action
from pearl.api.history import History
from pearl.api.observation import Observation
from pearl.history_summarization_modules.history_summarization_module import (
    HistorySummarizationModule,
    SubjectiveState,
)


class IdentityHistorySummarizationModule(HistorySummarizationModule):
    """
    A history summarization module that simply uses the original observations.

    No summarization is performed:
    - `summarize_history` stores the latest observation as the history and
      returns that same object as the subjective state.
    - `forward` returns its input unchanged.
    - The `action` argument of `summarize_history` is not used.
    """

    def __init__(self) -> None:
        # Zero-argument super() (PEP 3135) is equivalent to the explicit
        # two-argument form and keeps working if the class is renamed.
        super().__init__()
        # Most recent observation passed to `summarize_history`; `None` until
        # the first call and again after `reset`.
        self.history: History = None

    def summarize_history(
        self, observation: Observation, action: Optional[Action]
    ) -> SubjectiveState:
        """
        Record `observation` as the current history and return it unchanged.

        Args:
            observation: The observation to record.
            action: Unused by this module.

        Returns:
            The same `observation` object that was stored (no copy is made).
        """
        self.history = observation
        # pyre-fixme[7]: incompatible return type
        # Due to currently incorrect assumption that SubjectiveState
        # is always a Tensor (not the case for tabular Q-learning, for example)
        return observation

    def get_history(self) -> History:
        """Return the most recently stored observation (`None` if there is none)."""
        return self.history

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity mapping: return `x` unchanged."""
        return x

    def reset(self) -> None:
        """Clear the stored history back to `None`."""
        self.history = None

Classes

class IdentityHistorySummarizationModule

A history summarization module that simply uses the original observations.

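The constructor takes no arguments. It calls the parent constructor, which initializes the internal Module state shared by both nn.Module and ScriptModule, and then sets the stored history to None.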

Expand source code
class IdentityHistorySummarizationModule(HistorySummarizationModule):
    """
    A history summarization module that simply uses the original observations.

    No summarization is performed:
    - `summarize_history` stores the latest observation as the history and
      returns that same object as the subjective state.
    - `forward` returns its input unchanged.
    - The `action` argument of `summarize_history` is not used.
    """

    def __init__(self) -> None:
        # Zero-argument super() (PEP 3135) is equivalent to the explicit
        # two-argument form and keeps working if the class is renamed.
        super().__init__()
        # Most recent observation passed to `summarize_history`; `None` until
        # the first call and again after `reset`.
        self.history: History = None

    def summarize_history(
        self, observation: Observation, action: Optional[Action]
    ) -> SubjectiveState:
        """
        Record `observation` as the current history and return it unchanged.

        Args:
            observation: The observation to record.
            action: Unused by this module.

        Returns:
            The same `observation` object that was stored (no copy is made).
        """
        self.history = observation
        # pyre-fixme[7]: incompatible return type
        # Due to currently incorrect assumption that SubjectiveState
        # is always a Tensor (not the case for tabular Q-learning, for example)
        return observation

    def get_history(self) -> History:
        """Return the most recently stored observation (`None` if there is none)."""
        return self.history

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Identity mapping: return `x` unchanged."""
        return x

    def reset(self) -> None:
        """Clear the stored history back to `None`."""
        self.history = None

Ancestors

- pearl.history_summarization_modules.history_summarization_module.HistorySummarizationModule

Methods

def get_history(self) -> object
Expand source code
def get_history(self) -> History:
    """Hand back whatever `summarize_history` last stored, or `None` before any call or after `reset`."""
    latest = self.history
    return latest
def reset(self) -> None
Expand source code
def reset(self) -> None:
    """Drop the stored history, leaving `self.history` as `None`."""
    setattr(self, "history", None)
def summarize_history(self, observation: object, action: Optional[torch.Tensor]) -> torch.Tensor
Expand source code
def summarize_history(
    self, observation: Observation, action: Optional[Action]
) -> SubjectiveState:
    """
    Identity summarization: remember `observation` and return it directly.

    `action` is not consulted. The returned object is the very same
    `observation` that is stored in `self.history` (no copy is made).
    """
    latest = observation
    self.history = latest
    # pyre-fixme[7]: incompatible return type
    # Due to currently incorrect assumption that SubjectiveState
    # is always a Tensor (not the case for tabular Q-learning, for example)
    return latest

Inherited members