Module pearl.replay_buffers.contextual_bandits.discrete_contextual_bandit_replay_buffer
Source code
import random
from typing import Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Reward
from pearl.api.state import SubjectiveState
from pearl.replay_buffers.tensor_based_replay_buffer import TensorBasedReplayBuffer
from pearl.replay_buffers.transition import Transition, TransitionBatch
from pearl.utils.tensor_like import assert_is_tensor_like
class DiscreteContextualBanditReplayBuffer(TensorBasedReplayBuffer):
    """
    DiscreteContextualBanditReplayBuffer differs from other replay buffers
    in the following ways:
    - No next-action or next-state fields are stored.
    - The action is stored as an action index rather than an action value.
    - done is not needed, since for a contextual bandit it is always True.
    """

    def __init__(self, capacity: int) -> None:
        super(DiscreteContextualBanditReplayBuffer, self).__init__(
            capacity=capacity,
            has_next_state=False,
            has_next_action=False,
            has_next_available_actions=False,
        )

    def push(
        self,
        state: SubjectiveState,
        action: Action,
        reward: Reward,
        next_state: SubjectiveState,
        curr_available_actions: ActionSpace,
        next_available_actions: ActionSpace,
        done: bool,
        max_number_actions: Optional[int] = None,
        cost: Optional[float] = None,
    ) -> None:
        # The signature of push matches the other replay buffers so that
        # PearlAgent can call it uniformly; next_state, the action spaces,
        # and done are ignored here.
        # TODO: add curr_available_actions and curr_available_actions_mask
        # if needed in the future.
        action = assert_is_tensor_like(action)
        self.memory.append(
            Transition(
                state=self._process_single_state(state),
                action=action,
                reward=self._process_single_reward(reward),
            ).to(self.device)
        )

    def sample(self, batch_size: int) -> TransitionBatch:
        # Uniformly sample stored transitions and collate them into batched tensors.
        samples = random.sample(self.memory, batch_size)
        return TransitionBatch(
            state=torch.cat([x.state for x in samples]),
            action=torch.stack([x.action for x in samples]),
            reward=torch.cat([x.reward for x in samples]),
        ).to(self.device)
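Below is a minimal, illustrative usage sketch (not part of the module). It assumes contexts are 1-D float tensors and actions are scalar index tensors, and it passes None placeholders for the arguments this buffer ignores; next_state is passed the context again only to satisfy the signature.

# Illustrative usage sketch; assumes plain torch tensors for states and actions.
import torch

buffer = DiscreteContextualBanditReplayBuffer(capacity=1000)

for _ in range(16):
    context = torch.rand(4)        # observed context (SubjectiveState)
    action_idx = torch.tensor(2)   # index of the chosen arm (Action)
    reward = 1.0                   # observed reward (Reward)
    buffer.push(
        state=context,
        action=action_idx,
        reward=reward,
        next_state=context,            # ignored by this buffer
        curr_available_actions=None,   # ignored by this buffer
        next_available_actions=None,   # ignored by this buffer
        done=True,                     # always True for a bandit interaction
    )

batch = buffer.sample(batch_size=8)
# Expected shapes (assuming _process_single_state adds a leading batch dimension):
# batch.state: (8, 4), batch.action: (8,), batch.reward: (8,)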
Classes
class DiscreteContextualBanditReplayBuffer (capacity: int)
DiscreteContextualBanditReplayBuffer differs from other replay buffers in the following ways:
- No next-action or next-state fields are stored.
- The action is stored as an action index rather than an action value.
- done is not needed, since for a contextual bandit it is always True.
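Because the stored action is an index into a discrete action space, a sampled batch pairs each context with the arm that was actually played. The sketch below (names are hypothetical, not part of Pearl's API) shows how such indices are typically used to pick out the predicted value of the chosen arm from a per-arm model output.

# Hypothetical illustration of consuming action indices from a sampled batch.
import torch

num_arms = 5
predicted_values = torch.rand(8, num_arms)       # e.g. a model's output for batch.state
chosen_arms = torch.randint(0, num_arms, (8,))   # plays the role of batch.action
chosen_values = predicted_values.gather(1, chosen_arms.unsqueeze(1)).squeeze(1)
# chosen_values has shape (8,): predicted reward of the arm that was actually played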
Ancestors
- TensorBasedReplayBuffer
- ReplayBuffer
- abc.ABC