Module pearl.action_representation_modules.binary_action_representation_module

Expand source code
import torch

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)


class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform action indices into their binary (bit-vector) representation.

    Each index ``i`` is encoded as ``bits_num`` values in {0.0, 1.0},
    least-significant bit first: element ``k`` of the encoding is bit ``k``
    of ``i``. For example, with ``bits_num=3`` the index 6 (0b110) becomes
    ``[0.0, 1.0, 1.0]``.

    Args:
        bits_num: Number of bits per encoded index. Indices in
            ``[0, 2**bits_num)`` are represented exactly; bits above
            ``bits_num - 1`` of larger indices are dropped.
    """

    def __init__(self, bits_num: int) -> None:
        super().__init__()
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; the result has shape ``x.shape + (bits_num,)``."""
        return self.binary(x)

    def binary(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the little-endian binary encoding of the indices in ``x``.

        Args:
            x: Tensor of action indices. Integer, bool, or integral-valued
                floating-point tensors are accepted; values are converted to
                ``torch.long`` first (floats are truncated toward zero).

        Returns:
            A ``torch.float32`` tensor of shape ``x.shape + (bits_num,)``
            holding 0.0/1.0 bit values.
        """
        # Bit weights 1, 2, 4, ..., created directly on x's device so no
        # host-to-device copy happens on every call.
        mask = 2 ** torch.arange(self._bits_num, device=x.device)
        # bitwise_and is not defined for floating-point tensors; casting to
        # long lets float-typed action tensors be encoded as well.
        bits = x.long().unsqueeze(-1).bitwise_and(mask).ne(0)
        return bits.to(dtype=torch.float32)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct actions representable: ``2**bits_num``."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Length of each encoded action vector (``bits_num``)."""
        return self._bits_num

Classes

class BinaryActionTensorRepresentationModule (bits_num: int)

Transform index to its binary representation.

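Encodes each action index as a `bits_num`-bit binary vector, so up to `2**bits_num` distinct actions can be represented. (Inherited constructor note: initializes internal Module state, shared by both nn.Module and ScriptModule.)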

Expand source code
class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform action indices into their binary (bit-vector) representation.

    Each index ``i`` is encoded as ``bits_num`` values in {0.0, 1.0},
    least-significant bit first: element ``k`` of the encoding is bit ``k``
    of ``i``. For example, with ``bits_num=3`` the index 6 (0b110) becomes
    ``[0.0, 1.0, 1.0]``.

    Args:
        bits_num: Number of bits per encoded index. Indices in
            ``[0, 2**bits_num)`` are represented exactly; bits above
            ``bits_num - 1`` of larger indices are dropped.
    """

    def __init__(self, bits_num: int) -> None:
        super().__init__()
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; the result has shape ``x.shape + (bits_num,)``."""
        return self.binary(x)

    def binary(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the little-endian binary encoding of the indices in ``x``.

        Args:
            x: Tensor of action indices. Integer, bool, or integral-valued
                floating-point tensors are accepted; values are converted to
                ``torch.long`` first (floats are truncated toward zero).

        Returns:
            A ``torch.float32`` tensor of shape ``x.shape + (bits_num,)``
            holding 0.0/1.0 bit values.
        """
        # Bit weights 1, 2, 4, ..., created directly on x's device so no
        # host-to-device copy happens on every call.
        mask = 2 ** torch.arange(self._bits_num, device=x.device)
        # bitwise_and is not defined for floating-point tensors; casting to
        # long lets float-typed action tensors be encoded as well.
        bits = x.long().unsqueeze(-1).bitwise_and(mask).ne(0)
        return bits.to(dtype=torch.float32)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct actions representable: ``2**bits_num``."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Length of each encoded action vector (``bits_num``)."""
        return self._bits_num

Ancestors

- ActionRepresentationModule

Instance variables

var max_number_actions : int
Expand source code
@property
def max_number_actions(self) -> int:
    """Upper bound on how many distinct actions the encoding can represent."""
    count: int = self._max_number_actions
    return count
var representation_dim : int
Expand source code
@property
def representation_dim(self) -> int:
    """Length of each encoded action vector, i.e. the configured bit count."""
    dim: int = self._bits_num
    return dim

Methods

def binary(self, x: torch.Tensor) -> torch.Tensor
Expand source code
def binary(self, x: torch.Tensor) -> torch.Tensor:
    """
    Return the little-endian binary encoding of the indices in ``x``.

    Element ``k`` of each encoding is bit ``k`` of the corresponding index,
    for ``k`` in ``range(self._bits_num)``; higher bits are dropped.

    Args:
        x: Tensor of action indices. Integer, bool, or integral-valued
            floating-point tensors are accepted; values are converted to
            ``torch.long`` first (floats are truncated toward zero).

    Returns:
        A ``torch.float32`` tensor of shape ``x.shape + (self._bits_num,)``
        holding 0.0/1.0 bit values.
    """
    # Bit weights 1, 2, 4, ..., created directly on x's device so no
    # host-to-device copy happens on every call.
    mask = 2 ** torch.arange(self._bits_num, device=x.device)
    # bitwise_and is not defined for floating-point tensors; casting to
    # long lets float-typed action tensors be encoded as well.
    bits = x.long().unsqueeze(-1).bitwise_and(mask).ne(0)
    return bits.to(dtype=torch.float32)

Inherited members