Module pearl.action_representation_modules.one_hot_action_representation_module

Expand source code
import torch
import torch.nn.functional as F

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)


class OneHotActionTensorRepresentationModule(ActionRepresentationModule):
    """
    A one-hot action representation module.

    Encodes integer action indices as one-hot vectors whose length is
    ``max_number_actions``.

    Args:
        max_number_actions: Number of distinct discrete actions. Used as the
            number of classes of the one-hot encoding, and therefore as the
            representation dimension.
    """

    def __init__(self, max_number_actions: int) -> None:
        super().__init__()
        self._max_number_actions = max_number_actions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the one-hot encoding of the action indices in ``x``.

        ``x`` is cast with ``.long()`` and passed to ``F.one_hot``, which adds
        a trailing dimension of size ``max_number_actions``. The dimension
        that was previously last (now at position -2) is then squeezed; it is
        removed only if it has size 1.

        NOTE(review): presumably ``x`` has shape (batch_size, 1), which gives
        a (batch_size, max_number_actions) result. Confirm this against the
        callers.
        """
        encoded = F.one_hot(x.long(), num_classes=self._max_number_actions)
        # Output: (batch_size x action_dim), where action_dim equals
        # self.representation_dim (== max_number_actions).
        return encoded.squeeze(dim=-2)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct actions this module was configured to encode."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Size of the last dimension of ``forward``'s output (``max_number_actions``)."""
        return self._max_number_actions

Classes

class OneHotActionTensorRepresentationModule (max_number_actions: int)

A one-hot action representation module.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class OneHotActionTensorRepresentationModule(ActionRepresentationModule):
    """
    A one-hot action representation module.

    Encodes integer action indices as one-hot vectors whose length is
    ``max_number_actions``.

    Args:
        max_number_actions: Number of distinct discrete actions. Used as the
            number of classes of the one-hot encoding, and therefore as the
            representation dimension.
    """

    def __init__(self, max_number_actions: int) -> None:
        super().__init__()
        self._max_number_actions = max_number_actions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the one-hot encoding of the action indices in ``x``.

        ``x`` is cast with ``.long()`` and passed to ``F.one_hot``, which adds
        a trailing dimension of size ``max_number_actions``. The dimension
        that was previously last (now at position -2) is then squeezed; it is
        removed only if it has size 1.

        NOTE(review): presumably ``x`` has shape (batch_size, 1), which gives
        a (batch_size, max_number_actions) result. Confirm this against the
        callers.
        """
        encoded = F.one_hot(x.long(), num_classes=self._max_number_actions)
        # Output: (batch_size x action_dim), where action_dim equals
        # self.representation_dim (== max_number_actions).
        return encoded.squeeze(dim=-2)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct actions this module was configured to encode."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Size of the last dimension of ``forward``'s output (``max_number_actions``)."""
        return self._max_number_actions

Ancestors

pearl.action_representation_modules.action_representation_module.ActionRepresentationModule

Instance variables

var max_number_actions : int
Expand source code
@property
def max_number_actions(self) -> int:
    """Number of distinct actions this module was configured to encode."""
    configured_count = self._max_number_actions
    return configured_count
var representation_dim : int
Expand source code
@property
def representation_dim(self) -> int:
    """Width of the one-hot representation; same value as ``max_number_actions``."""
    width = self._max_number_actions
    return width

Inherited members