Module pearl.policy_learners.exploration_modules.sequential_decision_making.deep_exploration
Source code
from typing import Optional

import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch.nn import functional as F
class DeepExploration(ExplorationModule):
    r"""An exploration strategy that follows a policy based on a randomly
    drawn value function (from its posterior distribution), an idea that was
    developed in [1, 2, 3]. The implementation is based on [3] and uses an
    ensemble of Q-value functions.

    [1] Ian Osband, Daniel Russo, and Benjamin Van Roy, (More) efficient
        reinforcement learning via posterior sampling. Advances in Neural
        Information Processing Systems, 2013. https://arxiv.org/abs/1306.0940.
    [2] Ian Osband, Benjamin Van Roy, Daniel Russo, and Zheng Wen, Deep exploration
        via randomized value functions. Journal of Machine Learning Research, 2019.
        https://arxiv.org/abs/1703.07608.
    [3] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin
        Van Roy, Deep exploration via bootstrapped DQN. Advances in Neural
        Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.

    Args:
        q_ensemble_network (EnsembleQValueNetwork): A network that outputs
            a tensor of shape (num_samples, num_actions) where each row
            contains the Q-values of all possible actions.
    """

    def __init__(
        self,
        q_ensemble_network: EnsembleQValueNetwork,
    ) -> None:
        super(DeepExploration, self).__init__()
        self.q_ensemble_network = q_ensemble_network

    def act(
        self,
        subjective_state: SubjectiveState,
        action_space: ActionSpace,
        exploit_action: Optional[Action] = None,
        values: Optional[torch.Tensor] = None,
        action_availability_mask: Optional[torch.Tensor] = None,
        representation: Optional[torch.nn.Module] = None,
    ) -> Action:
        assert isinstance(action_space, DiscreteActionSpace)
        states_repeated = torch.repeat_interleave(
            subjective_state.unsqueeze(0), action_space.n, dim=0
        )
        # (action_space_size, state_dim)
        actions = F.one_hot(torch.arange(0, action_space.n)).to(subjective_state.device)
        # (action_space_size, action_dim)
        with torch.no_grad():
            q_values = self.q_ensemble_network.get_q_values(
                state_batch=states_repeated, action_batch=actions, persistent=True
            )
            # this is a single forward pass since all available
            # actions are already stacked together
        return torch.argmax(q_values).view((-1))

    def reset(self) -> None:  # noqa: B027
        # sample a new epistemic index (i.e., a Q-network) at the beginning of a
        # new episode for temporally consistent exploration
        self.q_ensemble_network.resample_epistemic_index()
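For orientation only, here is a minimal driving loop. ToyEnsemble below is a hypothetical stand-in, not part of Pearl: it merely duck-types the two calls this module makes on its network (get_q_values(state_batch, action_batch, persistent=True) and resample_epistemic_index()) so that the reset/act cycle can be shown end to end; in real use an EnsembleQValueNetwork instance is passed instead. The dimensions, the random states, and the DiscreteActionSpace construction (assumed to accept a list of action tensors) are illustrative assumptions.

# Illustrative sketch only: ToyEnsemble is a hypothetical stand-in for
# EnsembleQValueNetwork that implements just the interface DeepExploration uses.
import torch
import torch.nn as nn

from pearl.policy_learners.exploration_modules.sequential_decision_making.deep_exploration import (
    DeepExploration,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class ToyEnsemble(nn.Module):
    """Hypothetical ensemble of K independent Q-heads over (state, action) pairs."""

    def __init__(self, state_dim: int, action_dim: int, ensemble_size: int) -> None:
        super().__init__()
        self.heads = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(state_dim + action_dim, 32), nn.ReLU(), nn.Linear(32, 1))
                for _ in range(ensemble_size)
            ]
        )
        self.ensemble_size = ensemble_size
        self.z = 0  # currently sampled epistemic index

    def resample_epistemic_index(self) -> None:
        # pick one head uniformly at random; it stays fixed for the whole episode
        self.z = int(torch.randint(self.ensemble_size, (1,)).item())

    def get_q_values(self, state_batch, action_batch, persistent=True):
        # Q-values of the currently sampled head for each stacked (state, action) pair
        x = torch.cat([state_batch, action_batch.float()], dim=-1)
        return self.heads[self.z](x).squeeze(-1)


state_dim, num_actions = 4, 3
explore = DeepExploration(q_ensemble_network=ToyEnsemble(state_dim, num_actions, ensemble_size=5))
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(num_actions)])

for _episode in range(2):
    explore.reset()  # sample one Q-head; keep it for the whole episode
    for _step in range(3):
        state = torch.randn(state_dim)  # placeholder for the agent's subjective state
        action = explore.act(subjective_state=state, action_space=action_space)
        print(action)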
Classes
class DeepExploration (q_ensemble_network: EnsembleQValueNetwork)
An exploration strategy that follows a policy based on a randomly drawn value function (from its posterior distribution), an idea that was developed in [1, 2, 3]. The implementation is based on [3] and uses an ensemble of Q-value functions.
[1] Ian Osband, Daniel Russo, and Benjamin Van Roy, (More) efficient reinforcement learning via posterior sampling. Advances in Neural Information Processing Systems, 2013. https://arxiv.org/abs/1306.0940.
[2] Ian Osband, Benjamin Van Roy, Daniel Russo, and Zheng Wen, Deep exploration via randomized value functions. Journal of Machine Learning Research, 2019. https://arxiv.org/abs/1703.07608.
[3] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin Van Roy, Deep exploration via bootstrapped DQN. Advances in Neural Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.
Args: q_ensemble_network (EnsembleQValueNetwork): A network that outputs a tensor of shape (num_samples, num_actions) where each row contains the Q-values of all possible actions.
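Schematically, with {Q_1, ..., Q_K} denoting the K ensemble members (notation introduced here for exposition, not Pearl identifiers): at the start of each episode an index z is drawn uniformly from {1, ..., K} (reset, via resample_epistemic_index), and every action of that episode is chosen greedily against the single sampled member, a_t = argmax_a Q_z(s_t, a) (act). Committing to one posterior sample for a whole episode is what makes the exploration temporally consistent ("deep"), in contrast to per-step dithering such as epsilon-greedy.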
Ancestors
- ExplorationModule
- abc.ABC
Methods
def act(self, subjective_state: torch.Tensor, action_space: ActionSpace, exploit_action: Optional[torch.Tensor] = None, values: Optional[torch.Tensor] = None, action_availability_mask: Optional[torch.Tensor] = None, representation: Optional[torch.nn.modules.module.Module] = None) -> torch.Tensor
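For readers tracing the tensor shapes in act, a standalone sketch of the same evaluate-all-actions pattern (the 4-dimensional state, 3 actions, and the dummy linear scorer are arbitrary illustrations, not Pearl components): the state is repeated once per action and paired with a one-hot encoding of every action, so one batched forward pass scores all actions and argmax selects the greedy one.

# Shape walk-through of the evaluate-all-actions pattern used in act()
import torch
from torch.nn import functional as F

state = torch.randn(4)  # stands in for a subjective state of dimension 4
n = 3                   # number of discrete actions
states_repeated = torch.repeat_interleave(state.unsqueeze(0), n, dim=0)  # (3, 4)
actions = F.one_hot(torch.arange(0, n))  # (3, 3): one one-hot row per action
# a dummy linear scorer stands in for the currently sampled ensemble member;
# a single batched call evaluates every (state, action) pair at once
scorer = torch.nn.Linear(4 + 3, 1)
q_values = scorer(torch.cat([states_repeated, actions.float()], dim=-1)).squeeze(-1)  # (3,)
greedy_action = torch.argmax(q_values).view((-1))  # tensor of shape (1,)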
Inherited members