Module pearl.replay_buffers.sequential_decision_making.bootstrap_replay_buffer
Expand source code
import random
from typing import Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Reward
from pearl.api.state import SubjectiveState
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import ( # noqa E501
FIFOOffPolicyReplayBuffer,
)
from pearl.replay_buffers.transition import (
TransitionWithBootstrapMask,
TransitionWithBootstrapMaskBatch,
)
class BootstrapReplayBuffer(FIFOOffPolicyReplayBuffer):
r"""A ensemble replay buffer that supports the implementation of the
masking distribution used in Bootstrapped DQN, as described in [1]. This
implementation uses a Bernoulli(p) masking distribution (see Appendix 3.1
of [1]). The `k`-th Q-network receives an independently drawn mask
`w_k ~ Bernoulli(p)` for each piece of experience, and `w_k = 1` means
the experience is included in the training data.
[1] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin
Van Roy, Deep exploration via bootstrapped DQN. Advances in Neural
Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.
Args:
capacity: Size of the replay buffer.
p: The parameter of the Bernoulli masking distribution.
ensemble_size: The number of Q-networks in the ensemble.
has_next_state: Whether each piece of experience includes the next state.
has_next_action: Whether each piece of experience includes the next action.
        has_next_available_actions: Whether each piece of experience includes the
next available actions.
"""
def __init__(
self,
capacity: int,
p: float,
ensemble_size: int,
) -> None:
super().__init__(capacity=capacity)
self.p = p
self.ensemble_size = ensemble_size
def push(
self,
state: SubjectiveState,
action: Action,
reward: Reward,
next_state: SubjectiveState,
curr_available_actions: ActionSpace,
next_available_actions: ActionSpace,
done: bool,
max_number_actions: Optional[int] = None,
cost: Optional[float] = None,
) -> None:
# sample the bootstrap mask from Bernoulli(p) on each push
probs = torch.tensor(self.p).repeat(1, self.ensemble_size)
bootstrap_mask = torch.bernoulli(probs)
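        # `bootstrap_mask` has shape (1, ensemble_size); entry k equals 1 when
        # the k-th Q-network should train on this transition and 0 otherwise.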
(
curr_available_actions_tensor_with_padding,
curr_unavailable_actions_mask,
) = self._create_action_tensor_and_mask(
max_number_actions, curr_available_actions
)
(
next_available_actions_tensor_with_padding,
next_unavailable_actions_mask,
) = self._create_action_tensor_and_mask(
max_number_actions, next_available_actions
)
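        # The padded action tensors and unavailable-action masks allow transitions
        # with different numbers of available actions to be batched together.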
self.memory.append(
TransitionWithBootstrapMask(
state=self._process_single_state(state),
action=self._process_single_action(action),
reward=self._process_single_reward(reward),
next_state=self._process_single_state(next_state),
curr_available_actions=curr_available_actions_tensor_with_padding,
curr_unavailable_actions_mask=curr_unavailable_actions_mask,
next_available_actions=next_available_actions_tensor_with_padding,
next_unavailable_actions_mask=next_unavailable_actions_mask,
done=self._process_single_done(done),
cost=self._process_single_cost(cost),
bootstrap_mask=bootstrap_mask,
)
)
def sample(self, batch_size: int) -> TransitionWithBootstrapMaskBatch:
if batch_size > len(self):
raise ValueError(
f"Can't get a batch of size {batch_size} from a "
f"replay buffer with only {len(self)} elements"
)
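        # uniformly sample `batch_size` distinct transitions from the FIFO memory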
samples = random.sample(self.memory, batch_size)
transition_batch = self._create_transition_batch(
transitions=samples,
has_next_state=self._has_next_state,
has_next_action=self._has_next_action,
is_action_continuous=self.is_action_continuous,
has_next_available_actions=self._has_next_available_actions,
has_cost_available=self.has_cost_available,
)
bootstrap_mask_batch = torch.cat([x.bootstrap_mask for x in samples])
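        # Each stored mask has shape (1, ensemble_size), so concatenating along
        # dim 0 yields a (batch_size, ensemble_size) bootstrap mask batch.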
return TransitionWithBootstrapMaskBatch(
state=transition_batch.state,
action=transition_batch.action,
reward=transition_batch.reward,
next_state=transition_batch.next_state,
curr_available_actions=transition_batch.curr_available_actions,
curr_unavailable_actions_mask=transition_batch.curr_unavailable_actions_mask,
next_available_actions=transition_batch.next_available_actions,
next_unavailable_actions_mask=transition_batch.next_unavailable_actions_mask,
done=transition_batch.done,
bootstrap_mask=bootstrap_mask_batch,
)
Classes
class BootstrapReplayBuffer (capacity: int, p: float, ensemble_size: int)
An ensemble replay buffer that supports the implementation of the masking distribution used in Bootstrapped DQN, as described in [1]. This implementation uses a Bernoulli(p) masking distribution (see Appendix 3.1 of [1]). The `k`-th Q-network receives an independently drawn mask `w_k ~ Bernoulli(p)` for each piece of experience, and `w_k = 1` means the experience is included in the training data.
[1] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin Van Roy, Deep exploration via bootstrapped DQN. Advances in Neural Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.
Args
capacity
- Size of the replay buffer.
p
- The parameter of the Bernoulli masking distribution.
ensemble_size
- The number of Q-networks in the ensemble.
has_next_state
- Whether each piece of experience includes the next state.
has_next_action
- Whether each piece of experience includes the next action.
has_next_available_actions
- Whether each piece of experience includes the next available actions.
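Below is a minimal usage sketch. The constructor call follows the signature above; the concrete values (capacity, p, ensemble_size, batch size) and the standalone torch.bernoulli illustration of the masking distribution are assumptions made for this example, not part of the module.
import torch

from pearl.replay_buffers.sequential_decision_making.bootstrap_replay_buffer import (
    BootstrapReplayBuffer,
)

# Hypothetical configuration: an ensemble of 5 Q-networks, where each transition
# is included in a given network's training data with probability p = 0.5.
buffer = BootstrapReplayBuffer(capacity=10_000, p=0.5, ensemble_size=5)

# push(...) stores one transition together with a freshly drawn mask; after enough
# pushes, sample(batch_size) returns a TransitionWithBootstrapMaskBatch whose
# bootstrap_mask has shape (batch_size, ensemble_size):
#     batch = buffer.sample(batch_size=32)
#     mask_for_network_k = batch.bootstrap_mask[:, k]  # 1 -> use in network k's loss

# The masking distribution itself can be illustrated with plain torch, mirroring
# the draw performed on each push:
p, ensemble_size = 0.5, 5
mask = torch.bernoulli(torch.tensor(p).repeat(1, ensemble_size))
print(mask)  # e.g. tensor([[1., 0., 1., 1., 0.]])
Gating each network's training data with its own mask column is what produces the bootstrapped ensemble behavior described in [1].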
Ancestors
FIFOOffPolicyReplayBuffer
Inherited members