Module pearl.replay_buffers.transition
Expand source code
import dataclasses
from dataclasses import dataclass
from typing import Optional, TypeVar
import torch
from torch import Tensor
T = TypeVar("T", bound="Transition")
@dataclass(frozen=False)
class Transition:
"""
Transition is designed for one single set of data
"""
state: torch.Tensor
action: torch.Tensor
reward: torch.Tensor
done: torch.Tensor = torch.tensor(True) # default True is useful for bandits
next_state: Optional[torch.Tensor] = None
next_action: Optional[torch.Tensor] = None
curr_available_actions: Optional[torch.Tensor] = None
curr_unavailable_actions_mask: Optional[torch.Tensor] = None
next_available_actions: Optional[torch.Tensor] = None
next_unavailable_actions_mask: Optional[torch.Tensor] = None
weight: Optional[torch.Tensor] = None
cum_reward: Optional[torch.Tensor] = None
cost: Optional[torch.Tensor] = None
def to(self: T, device: torch.device) -> T:
# iterate over all fields, move to correct device
for f in dataclasses.fields(self.__class__):
if getattr(self, f.name) is not None:
super().__setattr__(
f.name,
torch.as_tensor(getattr(self, f.name)).to(device),
)
return self
@property
def device(self) -> torch.device:
return self.state.device
TB = TypeVar("TB", bound="TransitionBatch")
@dataclass(frozen=False)
class TransitionBatch:
"""
TransitionBatch is designed for data batch
"""
state: torch.Tensor
action: torch.Tensor
reward: torch.Tensor
done: torch.Tensor = torch.tensor(True) # default True is useful for bandits
next_state: Optional[torch.Tensor] = None
next_action: Optional[torch.Tensor] = None
curr_available_actions: Optional[torch.Tensor] = None
curr_unavailable_actions_mask: Optional[torch.Tensor] = None
next_available_actions: Optional[torch.Tensor] = None
next_unavailable_actions_mask: Optional[torch.Tensor] = None
weight: Optional[torch.Tensor] = None
cum_reward: Optional[torch.Tensor] = None
time_diff: Optional[torch.Tensor] = None
cost: Optional[torch.Tensor] = None
def to(self: TB, device: torch.device) -> TB:
# iterate over all fields
for f in dataclasses.fields(self.__class__):
if getattr(self, f.name) is not None:
item = getattr(self, f.name)
item = torch.as_tensor(item, device=device)
super().__setattr__(
f.name,
item,
)
return self
@property
def device(self) -> torch.device:
"""
The device where the batch lives.
"""
return self.state.device
def __len__(self) -> int:
return self.reward.shape[0]
@dataclass(frozen=False)
class TransitionWithBootstrapMask(Transition):
bootstrap_mask: Optional[torch.Tensor] = None
@dataclass(frozen=False)
class TransitionWithBootstrapMaskBatch(TransitionBatch):
bootstrap_mask: Optional[torch.Tensor] = None
def filter_batch_by_bootstrap_mask(
batch: TransitionWithBootstrapMaskBatch, z: Tensor
) -> TransitionBatch:
r"""A helper function that filters a `TransitionBatch` to only those transitions
that are marked as active (by its `bootstrap_mask` field) for a given ensemble
index `z`.
Args:
batch: The original `TransitionWithBootstrapMask`.
z: The ensemble index to filter on.
Returns:
A filtered `TransitionBatch`.
"""
mask: Optional[torch.Tensor] = batch.bootstrap_mask
def _filter_tensor(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if x is None or mask is None:
return None
return x[mask[:, z] == 1]
filtered_state = _filter_tensor(batch.state)
filtered_action = _filter_tensor(batch.action)
filtered_reward = _filter_tensor(batch.reward)
filtered_done = _filter_tensor(batch.done)
assert filtered_state is not None
assert filtered_action is not None
assert filtered_reward is not None
assert filtered_done is not None
return TransitionBatch(
state=filtered_state,
action=filtered_action,
reward=filtered_reward,
done=filtered_done,
next_state=_filter_tensor(batch.next_state),
next_action=_filter_tensor(batch.next_action),
curr_available_actions=_filter_tensor(batch.curr_available_actions),
curr_unavailable_actions_mask=_filter_tensor(
batch.curr_unavailable_actions_mask
),
next_available_actions=_filter_tensor(batch.next_available_actions),
next_unavailable_actions_mask=_filter_tensor(
batch.next_unavailable_actions_mask
),
weight=_filter_tensor(batch.weight),
cum_reward=_filter_tensor(batch.cum_reward),
cost=_filter_tensor(batch.cost),
).to(batch.device)
Functions
def filter_batch_by_bootstrap_mask(batch: TransitionWithBootstrapMaskBatch, z: torch.Tensor) ‑> TransitionBatch
-
A helper function that filters a
TransitionBatch
to only those transitions that are marked as active (by itsbootstrap_mask
field) for a given ensemble indexz
.Args
batch
- The original
TransitionWithBootstrapMask
. z
- The ensemble index to filter on.
Returns
A filtered
TransitionBatch
.Expand source code
def filter_batch_by_bootstrap_mask( batch: TransitionWithBootstrapMaskBatch, z: Tensor ) -> TransitionBatch: r"""A helper function that filters a `TransitionBatch` to only those transitions that are marked as active (by its `bootstrap_mask` field) for a given ensemble index `z`. Args: batch: The original `TransitionWithBootstrapMask`. z: The ensemble index to filter on. Returns: A filtered `TransitionBatch`. """ mask: Optional[torch.Tensor] = batch.bootstrap_mask def _filter_tensor(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if x is None or mask is None: return None return x[mask[:, z] == 1] filtered_state = _filter_tensor(batch.state) filtered_action = _filter_tensor(batch.action) filtered_reward = _filter_tensor(batch.reward) filtered_done = _filter_tensor(batch.done) assert filtered_state is not None assert filtered_action is not None assert filtered_reward is not None assert filtered_done is not None return TransitionBatch( state=filtered_state, action=filtered_action, reward=filtered_reward, done=filtered_done, next_state=_filter_tensor(batch.next_state), next_action=_filter_tensor(batch.next_action), curr_available_actions=_filter_tensor(batch.curr_available_actions), curr_unavailable_actions_mask=_filter_tensor( batch.curr_unavailable_actions_mask ), next_available_actions=_filter_tensor(batch.next_available_actions), next_unavailable_actions_mask=_filter_tensor( batch.next_unavailable_actions_mask ), weight=_filter_tensor(batch.weight), cum_reward=_filter_tensor(batch.cum_reward), cost=_filter_tensor(batch.cost), ).to(batch.device)
Classes
class Transition (state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None)
-
Transition is designed for one single set of data
Expand source code
@dataclass(frozen=False) class Transition: """ Transition is designed for one single set of data """ state: torch.Tensor action: torch.Tensor reward: torch.Tensor done: torch.Tensor = torch.tensor(True) # default True is useful for bandits next_state: Optional[torch.Tensor] = None next_action: Optional[torch.Tensor] = None curr_available_actions: Optional[torch.Tensor] = None curr_unavailable_actions_mask: Optional[torch.Tensor] = None next_available_actions: Optional[torch.Tensor] = None next_unavailable_actions_mask: Optional[torch.Tensor] = None weight: Optional[torch.Tensor] = None cum_reward: Optional[torch.Tensor] = None cost: Optional[torch.Tensor] = None def to(self: T, device: torch.device) -> T: # iterate over all fields, move to correct device for f in dataclasses.fields(self.__class__): if getattr(self, f.name) is not None: super().__setattr__( f.name, torch.as_tensor(getattr(self, f.name)).to(device), ) return self @property def device(self) -> torch.device: return self.state.device
Subclasses
Class variables
var action : torch.Tensor
var cost : Optional[torch.Tensor]
var cum_reward : Optional[torch.Tensor]
var curr_available_actions : Optional[torch.Tensor]
var done : torch.Tensor
var next_action : Optional[torch.Tensor]
var next_available_actions : Optional[torch.Tensor]
var next_state : Optional[torch.Tensor]
var reward : torch.Tensor
var state : torch.Tensor
var weight : Optional[torch.Tensor]
Instance variables
var device : torch.device
-
Expand source code
@property def device(self) -> torch.device: return self.state.device
Methods
def to(self: ~T, device: torch.device) ‑> ~T
-
Expand source code
def to(self: T, device: torch.device) -> T: # iterate over all fields, move to correct device for f in dataclasses.fields(self.__class__): if getattr(self, f.name) is not None: super().__setattr__( f.name, torch.as_tensor(getattr(self, f.name)).to(device), ) return self
class TransitionBatch (state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, time_diff: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None)
-
TransitionBatch is designed for data batch
Expand source code
@dataclass(frozen=False) class TransitionBatch: """ TransitionBatch is designed for data batch """ state: torch.Tensor action: torch.Tensor reward: torch.Tensor done: torch.Tensor = torch.tensor(True) # default True is useful for bandits next_state: Optional[torch.Tensor] = None next_action: Optional[torch.Tensor] = None curr_available_actions: Optional[torch.Tensor] = None curr_unavailable_actions_mask: Optional[torch.Tensor] = None next_available_actions: Optional[torch.Tensor] = None next_unavailable_actions_mask: Optional[torch.Tensor] = None weight: Optional[torch.Tensor] = None cum_reward: Optional[torch.Tensor] = None time_diff: Optional[torch.Tensor] = None cost: Optional[torch.Tensor] = None def to(self: TB, device: torch.device) -> TB: # iterate over all fields for f in dataclasses.fields(self.__class__): if getattr(self, f.name) is not None: item = getattr(self, f.name) item = torch.as_tensor(item, device=device) super().__setattr__( f.name, item, ) return self @property def device(self) -> torch.device: """ The device where the batch lives. """ return self.state.device def __len__(self) -> int: return self.reward.shape[0]
Subclasses
Class variables
var action : torch.Tensor
var cost : Optional[torch.Tensor]
var cum_reward : Optional[torch.Tensor]
var curr_available_actions : Optional[torch.Tensor]
var done : torch.Tensor
var next_action : Optional[torch.Tensor]
var next_available_actions : Optional[torch.Tensor]
var next_state : Optional[torch.Tensor]
var reward : torch.Tensor
var state : torch.Tensor
var time_diff : Optional[torch.Tensor]
var weight : Optional[torch.Tensor]
Instance variables
var device : torch.device
-
The device where the batch lives.
Expand source code
@property def device(self) -> torch.device: """ The device where the batch lives. """ return self.state.device
Methods
def to(self: ~TB, device: torch.device) ‑> ~TB
-
Expand source code
def to(self: TB, device: torch.device) -> TB: # iterate over all fields for f in dataclasses.fields(self.__class__): if getattr(self, f.name) is not None: item = getattr(self, f.name) item = torch.as_tensor(item, device=device) super().__setattr__( f.name, item, ) return self
class TransitionWithBootstrapMask (state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None, bootstrap_mask: Optional[torch.Tensor] = None)
-
TransitionWithBootstrapMask(state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None, bootstrap_mask: Optional[torch.Tensor] = None)
Expand source code
@dataclass(frozen=False) class TransitionWithBootstrapMask(Transition): bootstrap_mask: Optional[torch.Tensor] = None
Ancestors
Class variables
var bootstrap_mask : Optional[torch.Tensor]
class TransitionWithBootstrapMaskBatch (state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, time_diff: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None, bootstrap_mask: Optional[torch.Tensor] = None)
-
TransitionWithBootstrapMaskBatch(state: torch.Tensor, action: torch.Tensor, reward: torch.Tensor, done: torch.Tensor = tensor(True), next_state: Optional[torch.Tensor] = None, next_action: Optional[torch.Tensor] = None, curr_available_actions: Optional[torch.Tensor] = None, curr_unavailable_actions_mask: Optional[torch.Tensor] = None, next_available_actions: Optional[torch.Tensor] = None, next_unavailable_actions_mask: Optional[torch.Tensor] = None, weight: Optional[torch.Tensor] = None, cum_reward: Optional[torch.Tensor] = None, time_diff: Optional[torch.Tensor] = None, cost: Optional[torch.Tensor] = None, bootstrap_mask: Optional[torch.Tensor] = None)
Expand source code
@dataclass(frozen=False) class TransitionWithBootstrapMaskBatch(TransitionBatch): bootstrap_mask: Optional[torch.Tensor] = None
Ancestors
Class variables
var bootstrap_mask : Optional[torch.Tensor]
Inherited members