Module pearl.action_representation_modules.identity_action_representation_module

Expand source code
import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)


class IdentityActionRepresentationModule(ActionRepresentationModule):
    """
    A trivial action representation module: ``forward`` returns its input
    tensor unchanged.
    """

    def __init__(
        self,
        max_number_actions: int = -1,
        representation_dim: int = -1,
    ) -> None:
        """
        Record the two integers exposed by the read-only properties below.

        Args:
            max_number_actions: value reported by the ``max_number_actions``
                property. Defaults to -1 (presumably a "not specified"
                sentinel; confirm against callers).
            representation_dim: value reported by the ``representation_dim``
                property. Defaults to -1 (same caveat as above).
        """
        super().__init__()
        self._max_number_actions = max_number_actions
        self._representation_dim = representation_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` as-is; no copy or transformation is made."""
        return x

    @property
    def max_number_actions(self) -> int:
        """The ``max_number_actions`` value given at construction time."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """The ``representation_dim`` value given at construction time."""
        return self._representation_dim

Classes

class IdentityActionRepresentationModule (max_number_actions: int = -1, representation_dim: int = -1)

A trivial class that outputs actions identically to its input.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class IdentityActionRepresentationModule(ActionRepresentationModule):
    """
    A trivial action representation module: ``forward`` returns its input
    tensor unchanged.
    """

    def __init__(
        self,
        max_number_actions: int = -1,
        representation_dim: int = -1,
    ) -> None:
        """
        Record the two integers exposed by the read-only properties below.

        Args:
            max_number_actions: value reported by the ``max_number_actions``
                property. Defaults to -1 (presumably a "not specified"
                sentinel; confirm against callers).
            representation_dim: value reported by the ``representation_dim``
                property. Defaults to -1 (same caveat as above).
        """
        super().__init__()
        self._max_number_actions = max_number_actions
        self._representation_dim = representation_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` as-is; no copy or transformation is made."""
        return x

    @property
    def max_number_actions(self) -> int:
        """The ``max_number_actions`` value given at construction time."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """The ``representation_dim`` value given at construction time."""
        return self._representation_dim

Ancestors

Instance variables

var max_number_actions : int
Expand source code
@property
def max_number_actions(self) -> int:
    """Return the value stored in ``self._max_number_actions``."""
    return self._max_number_actions
var representation_dim : int
Expand source code
@property
def representation_dim(self) -> int:
    """Return the value stored in ``self._representation_dim``."""
    return self._representation_dim

Inherited members