Module pearl.history_summarization_modules.history_summarization_module
Expand source code
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.nn as nn
from pearl.api.action import Action
from pearl.api.history import History
from pearl.api.observation import Observation
from pearl.api.state import SubjectiveState
class HistorySummarizationModule(ABC, nn.Module):
    """
    An abstract interface for history summarization modules.

    Concrete subclasses must implement every abstract method declared here
    before they can be instantiated. Because the class also derives from
    ``nn.Module``, implementations are regular PyTorch modules.
    """

    @abstractmethod
    def summarize_history(
        self, observation: Observation, action: Optional[Action]
    ) -> SubjectiveState:
        """
        Produce a subjective state from an observation and an optional action.

        Args:
            observation: the observation to incorporate.
            action: the accompanying action, or ``None``.
                NOTE(review): presumably ``None`` means no action has been
                taken yet (e.g. at the start of an episode) -- confirm
                against callers.

        Returns:
            The resulting subjective state.
        """

    @abstractmethod
    def get_history(self) -> History:
        """Return this module's current ``History``."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map an input tensor to an output tensor (the ``nn.Module`` forward pass).

        NOTE(review): expected tensor shapes are not fixed by this interface;
        check the concrete implementation being used.
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the module.

        NOTE(review): presumably this clears any accumulated history --
        confirm against the concrete implementations.
        """
Classes
class HistorySummarizationModule (*args, **kwargs)
-
An abstract interface for history summarization modules.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class HistorySummarizationModule(ABC, nn.Module):
    """
    An abstract interface for history summarization modules.

    Concrete subclasses must implement every abstract method declared here
    before they can be instantiated. Because the class also derives from
    ``nn.Module``, implementations are regular PyTorch modules.
    """

    @abstractmethod
    def summarize_history(
        self, observation: Observation, action: Optional[Action]
    ) -> SubjectiveState:
        """
        Produce a subjective state from an observation and an optional action.

        Args:
            observation: the observation to incorporate.
            action: the accompanying action, or ``None``.
                NOTE(review): presumably ``None`` means no action has been
                taken yet (e.g. at the start of an episode) -- confirm
                against callers.

        Returns:
            The resulting subjective state.
        """

    @abstractmethod
    def get_history(self) -> History:
        """Return this module's current ``History``."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map an input tensor to an output tensor (the ``nn.Module`` forward pass).

        NOTE(review): expected tensor shapes are not fixed by this interface;
        check the concrete implementation being used.
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Reset the module.

        NOTE(review): presumably this clears any accumulated history --
        confirm against the concrete implementations.
        """
Ancestors
- abc.ABC
- torch.nn.modules.module.Module
Subclasses
- IdentityHistorySummarizationModule
- LSTMHistorySummarizationModule
- StackingHistorySummarizationModule
Methods
def forward(self, x: torch.Tensor) ‑> torch.Tensor
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Abstract forward pass: accepts a tensor ``x`` and is declared to return a tensor.

    The body is intentionally empty; concrete subclasses supply the computation.
    """
def get_history(self) ‑> object
-
Expand source code
@abstractmethod
def get_history(self) -> History:
    """
    Abstract accessor declared to return a ``History``.

    The body is intentionally empty; concrete subclasses supply the value.
    """
def reset(self) ‑> None
-
Expand source code
@abstractmethod
def reset(self) -> None:
    """
    Abstract reset hook; takes no arguments and returns nothing.

    NOTE(review): presumably implementations clear accumulated history --
    confirm against the concrete subclasses.
    """
def summarize_history(self, observation: object, action: Optional[torch.Tensor]) ‑> torch.Tensor
-
Expand source code
@abstractmethod
def summarize_history(
    self, observation: Observation, action: Optional[Action]
) -> SubjectiveState:
    """
    Abstract method declared to turn an observation and an optional action
    into a ``SubjectiveState``.

    Args:
        observation: the observation to incorporate.
        action: the accompanying action; may be ``None``.

    The body is intentionally empty; concrete subclasses supply the logic.
    """