Module pearl.action_representation_modules.binary_action_representation_module
Expand source code
import torch
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform index to its binary representation.

    Every integer action index is expanded into ``bits_num`` floats
    (each 0.0 or 1.0), ordered least-significant bit first.
    """

    def __init__(self, bits_num: int) -> None:
        """
        Args:
            bits_num: number of bits in each binary representation.

        Raises:
            ValueError: if ``bits_num`` is negative.
        """
        super().__init__()
        # Fail fast: a negative width would make ``2**bits_num`` a float and
        # would only surface later as an error inside ``torch.arange``.
        if bits_num < 0:
            raise ValueError(f"bits_num must be non-negative, got {bits_num}")
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the binary representation of ``x`` (see ``binary``)."""
        return self.binary(x)

    def binary(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert integer action indices to their bit vectors.

        Args:
            x: integer tensor of action indices (``bitwise_and`` does not
                accept floating-point tensors).

        Returns:
            A float32 tensor of shape ``x.shape + (bits_num,)`` where entry
            ``[..., i]`` is 1.0 if bit ``i`` (value ``2**i``) of the index is
            set and 0.0 otherwise.
        """
        # mask[i] == 2**i; build it directly on x's device to avoid an
        # allocation on the default device followed by a copy.
        mask = 2 ** torch.arange(self._bits_num, device=x.device)
        # Trailing axis of size 1 broadcasts against the mask.
        bits_set = x.unsqueeze(-1).bitwise_and(mask).ne(0)
        return bits_set.to(dtype=torch.float32)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct indices representable: ``2**bits_num``."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Length of each binary representation vector (``bits_num``)."""
        return self._bits_num
Classes
class BinaryActionTensorRepresentationModule (bits_num: int)
-
Transform index to its binary representation.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class BinaryActionTensorRepresentationModule(ActionRepresentationModule):
    """
    Transform index to its binary representation.

    Every integer action index is expanded into ``bits_num`` floats
    (each 0.0 or 1.0), ordered least-significant bit first.
    """

    def __init__(self, bits_num: int) -> None:
        """
        Args:
            bits_num: number of bits in each binary representation.

        Raises:
            ValueError: if ``bits_num`` is negative.
        """
        super().__init__()
        # Fail fast: a negative width would make ``2**bits_num`` a float and
        # would only surface later as an error inside ``torch.arange``.
        if bits_num < 0:
            raise ValueError(f"bits_num must be non-negative, got {bits_num}")
        self._bits_num = bits_num
        self._max_number_actions: int = 2**bits_num

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the binary representation of ``x`` (see ``binary``)."""
        return self.binary(x)

    def binary(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert integer action indices to their bit vectors.

        Args:
            x: integer tensor of action indices (``bitwise_and`` does not
                accept floating-point tensors).

        Returns:
            A float32 tensor of shape ``x.shape + (bits_num,)`` where entry
            ``[..., i]`` is 1.0 if bit ``i`` (value ``2**i``) of the index is
            set and 0.0 otherwise.
        """
        # mask[i] == 2**i; build it directly on x's device to avoid an
        # allocation on the default device followed by a copy.
        mask = 2 ** torch.arange(self._bits_num, device=x.device)
        # Trailing axis of size 1 broadcasts against the mask.
        bits_set = x.unsqueeze(-1).bitwise_and(mask).ne(0)
        return bits_set.to(dtype=torch.float32)

    @property
    def max_number_actions(self) -> int:
        """Number of distinct indices representable: ``2**bits_num``."""
        return self._max_number_actions

    @property
    def representation_dim(self) -> int:
        """Length of each binary representation vector (``bits_num``)."""
        return self._bits_num
Ancestors
- ActionRepresentationModule
- abc.ABC
- torch.nn.modules.module.Module
Instance variables
var max_number_actions : int
-
Expand source code
@property
def max_number_actions(self) -> int:
    """Return the stored count of representable action indices."""
    count = self._max_number_actions
    return count
var representation_dim : int
-
Expand source code
@property
def representation_dim(self) -> int:
    """Return the stored number of bits per representation vector."""
    width = self._bits_num
    return width
Methods
def binary(self, x: torch.Tensor) ‑> torch.Tensor
-
Expand source code
def binary(self, x: torch.Tensor) -> torch.Tensor:
    """
    Convert integer action indices to their bit vectors.

    Args:
        x: integer tensor of action indices (``bitwise_and`` does not
            accept floating-point tensors).

    Returns:
        A float32 tensor of shape ``x.shape + (self._bits_num,)`` where
        entry ``[..., i]`` is 1.0 if bit ``i`` (value ``2**i``) of the index
        is set and 0.0 otherwise.
    """
    # mask[i] == 2**i; build it directly on x's device to avoid an
    # allocation on the default device followed by a copy.
    mask = 2 ** torch.arange(self._bits_num, device=x.device)
    # Trailing axis of size 1 broadcasts against the mask.
    bits_set = x.unsqueeze(-1).bitwise_and(mask).ne(0)
    return bits_set.to(dtype=torch.float32)
Inherited members