Module pearl.action_representation_modules.binary_action_representation_module
Expand source code
import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform an action index into its binary (bit-vector) representation.

    Every integer index is expanded into ``bits_num`` values in {0.0, 1.0},
    ordered from the least significant bit to the most significant bit.
    For example, with ``bits_num=3`` the index 5 (0b101) becomes
    ``[1., 0., 1.]`` and the index 2 (0b010) becomes ``[0., 1., 0.]``.
    """

    def __init__(self, bits_num: int) -> None:
        """
        Args:
            bits_num: number of bits in the representation. It determines
                both the representation dimension and the number of
                distinct actions (``2**bits_num``).

        Raises:
            ValueError: if ``bits_num`` is negative.
        """
        # A negative exponent would turn 2**bits_num into a float (e.g. 0.5)
        # and only fail later inside torch.arange; reject it up front.
        if bits_num < 0:
            raise ValueError(f"bits_num must be non-negative, got {bits_num}")
        super().__init__()
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the binary representation of the indices in ``x``.

        The result has one extra trailing dimension of size ``bits_num``;
        a 1-D batch of indices therefore yields a
        (batch_size x representation_dim) tensor.
        """
        return self.binary(x)

    def binary(self, x: torch.Tensor) -> torch.Tensor:
        """
        Expand each index in ``x`` into its ``bits_num`` lowest bits.

        Args:
            x: tensor of indices. It must have an integer dtype, because
                ``bitwise_and`` is not defined for floating-point tensors.
                Bits above ``bits_num`` are ignored.

        Returns:
            A float32 tensor of shape ``x.shape + (bits_num,)`` holding 0.0
            or 1.0, least significant bit first.
        """
        # mask[i] == 2**i; created directly on x's device to avoid a copy.
        mask = 2 ** torch.arange(self._bits_num, device=x.device)
        # Broadcasting (..., 1) & (bits_num,) tests every bit of every index.
        bits_set = x.unsqueeze(-1).bitwise_and(mask).ne(0)
        return bits_set.to(dtype=torch.float32)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct indices representable, i.e. ``2**bits_num``."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Size of the trailing dimension of the output, i.e. ``bits_num``."""
        return self._bits_num
Classes
class BinaryActionTensorRepresentationModule (bits_num: int)- 
Transform index to its binary representation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform index to its binary representation.
    """
    def __init__(self, bits_num: int) -> None:
        super(BinaryActionTensorRepresentationModule, self).__init__()
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.binary(x)
        # (batch_size x action_dim)
    def binary(self, x: torch.Tensor) -> torch.Tensor:
        mask = 2 ** torch.arange(self._bits_num).to(device=x.device)
        x = x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
        return x.to(dtype=torch.float32)
    @property
    def max_number_actions(self) -> int:
        return self._max_number_actions
    @property
    def representation_dim(self) -> int:
        return self._bits_num
Ancestors
- ActionRepresentationModule
 - abc.ABC
 - torch.nn.modules.module.Module
 
Instance variables
var max_number_actions : int- 
Expand source code
@property
def max_number_actions(self) -> int:
    return self._max_number_actions
var representation_dim : int
Expand source code
@property
def representation_dim(self) -> int:
    return self._bits_num
Methods
def binary(self, x: torch.Tensor) ‑> torch.Tensor- 
Expand source code
def binary(self, x: torch.Tensor) -> torch.Tensor:
    mask = 2 ** torch.arange(self._bits_num).to(device=x.device)
    x = x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
    return x.to(dtype=torch.float32)
Inherited members