Module pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration
Expand source code
from typing import Any, Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.policy_learners.exploration_modules.common.score_exploration_base import (
ScoreExplorationBase,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
# TODO: Assumes discrete gym action space
class UCBExploration(ScoreExplorationBase):
"""
UCB exploration module.
"""
def __init__(self, alpha: float) -> None:
super(UCBExploration, self).__init__()
self._alpha = alpha
def sigma(
self,
subjective_state: SubjectiveState,
representation: torch.nn.Module,
) -> torch.Tensor:
"""
Args:
subjective_state: feature vector (either state,
or state and action features after concatenation)
Shape should be either (batch_size, action_count, feature_dim) or
(batch_size, feature_dim).
Returns:
sigma with shape (batch_size, action_count) or (batch_size, 1)
"""
sigma = representation.calculate_sigma(subjective_state)
nan_check = torch.isnan(sigma)
sigma = torch.where(nan_check, torch.zeros_like(sigma), sigma)
return sigma
def get_scores(
self,
subjective_state: SubjectiveState,
values: torch.Tensor,
action_space: ActionSpace,
representation: Optional[torch.nn.Module] = None,
exploit_action: Optional[Action] = None,
) -> torch.Tensor:
"""
Args:
subjective_state is in shape of (batch_size, feature_size)
values is in shape of (batch_size, action_count)
Returns:
return shape(batch_size, action_count)
or
Args:
subjective_state is in shape of (feature_size)
values is in shape of (action_count)
Returns:
return shape(action_count)
"""
assert isinstance(action_space, DiscreteActionSpace)
action_count = action_space.n
values = values.view(-1, action_count) # (batch_size, action_count)
sigma = self.sigma(
subjective_state=subjective_state,
# pyre-fixme[6]: For 2nd argument expected `Module` but got
# `Optional[Module]`.
representation=representation,
)
# a safe check before reshape sigma into values
sigma = sigma.view(values.shape)
ucb_scores = values + self._alpha * sigma
return ucb_scores.view(-1, action_space.n) # batch_size, action_count
class DisjointUCBExploration(UCBExploration):
"""
Same as UCBExploration, but with a separate bandit model for each action
"""
# pyre-fixme[14]: `sigma` overrides method defined in `UCBExploration`
# inconsistently.
def sigma(
self,
subjective_state: SubjectiveState,
representation: torch.nn.ModuleList,
) -> torch.Tensor:
"""
Args:
subjective_state: this is feature vector in shape, batch_size, action_count, feature
representation: a list of bandit models, one per action (arm)
"""
sigma = []
for i, arm_model in enumerate(representation):
sigma.append(
super(DisjointUCBExploration, self).sigma(
subjective_state=subjective_state[:, i, :],
representation=arm_model,
)
)
sigma = torch.stack(sigma)
# change from shape(action_count, batch_size) to shape(batch_size, action_count)
sigma = sigma.permute(1, 0)
return sigma
class VanillaUCBExploration(UCBExploration):
"""
Vanilla UCB exploration module with counter.
"""
def __init__(self) -> None:
super(VanillaUCBExploration, self).__init__(alpha=1)
# pyre-fixme[4]: Attribute must be annotated.
self.action_execution_count = {}
# pyre-fixme[4]: Attribute must be annotated.
self.action_executed = torch.tensor(1)
# pyre-fixme[14]: `sigma` overrides method defined in `UCBExploration`
# inconsistently.
def sigma(
self,
subjective_state: SubjectiveState,
action_space: ActionSpace,
representation: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
assert isinstance(action_space, DiscreteActionSpace)
exploration_bonus = torch.zeros((action_space.n)) # (action_space_size)
for action in action_space.actions:
if action not in self.action_execution_count:
self.action_execution_count[action] = 1
exploration_bonus[action] = torch.sqrt(
torch.log(self.action_executed) / self.action_execution_count[action]
)
return exploration_bonus
# TODO: We should make discrete action space itself iterable
# pyre-fixme[14]: `act` overrides method defined in `ScoreExplorationBase`
# inconsistently.
def act(
self,
subjective_state: SubjectiveState,
action_space: ActionSpace,
values: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
representation: Any = None,
exploit_action: Optional[Action] = None,
) -> Action:
selected_action = super().act(
subjective_state,
action_space,
values,
representation,
exploit_action,
)
self.action_execution_count[selected_action] += 1
self.action_executed += 1
return selected_action
Classes
class DisjointUCBExploration (alpha: float)
-
Same as UCBExploration, but with a separate bandit model for each action
Expand source code
class DisjointUCBExploration(UCBExploration): """ Same as UCBExploration, but with a separate bandit model for each action """ # pyre-fixme[14]: `sigma` overrides method defined in `UCBExploration` # inconsistently. def sigma( self, subjective_state: SubjectiveState, representation: torch.nn.ModuleList, ) -> torch.Tensor: """ Args: subjective_state: this is feature vector in shape, batch_size, action_count, feature representation: a list of bandit models, one per action (arm) """ sigma = [] for i, arm_model in enumerate(representation): sigma.append( super(DisjointUCBExploration, self).sigma( subjective_state=subjective_state[:, i, :], representation=arm_model, ) ) sigma = torch.stack(sigma) # change from shape(action_count, batch_size) to shape(batch_size, action_count) sigma = sigma.permute(1, 0) return sigma
Ancestors
Methods
def sigma(self, subjective_state: torch.Tensor, representation: torch.nn.modules.container.ModuleList) ‑> torch.Tensor
-
Args
subjective_state
- this is feature vector in shape, batch_size, action_count, feature
representation
- a list of bandit models, one per action (arm)
Expand source code
def sigma( self, subjective_state: SubjectiveState, representation: torch.nn.ModuleList, ) -> torch.Tensor: """ Args: subjective_state: this is feature vector in shape, batch_size, action_count, feature representation: a list of bandit models, one per action (arm) """ sigma = [] for i, arm_model in enumerate(representation): sigma.append( super(DisjointUCBExploration, self).sigma( subjective_state=subjective_state[:, i, :], representation=arm_model, ) ) sigma = torch.stack(sigma) # change from shape(action_count, batch_size) to shape(batch_size, action_count) sigma = sigma.permute(1, 0) return sigma
Inherited members
class UCBExploration (alpha: float)
-
UCB exploration module.
Expand source code
class UCBExploration(ScoreExplorationBase): """ UCB exploration module. """ def __init__(self, alpha: float) -> None: super(UCBExploration, self).__init__() self._alpha = alpha def sigma( self, subjective_state: SubjectiveState, representation: torch.nn.Module, ) -> torch.Tensor: """ Args: subjective_state: feature vector (either state, or state and action features after concatenation) Shape should be either (batch_size, action_count, feature_dim) or (batch_size, feature_dim). Returns: sigma with shape (batch_size, action_count) or (batch_size, 1) """ sigma = representation.calculate_sigma(subjective_state) nan_check = torch.isnan(sigma) sigma = torch.where(nan_check, torch.zeros_like(sigma), sigma) return sigma def get_scores( self, subjective_state: SubjectiveState, values: torch.Tensor, action_space: ActionSpace, representation: Optional[torch.nn.Module] = None, exploit_action: Optional[Action] = None, ) -> torch.Tensor: """ Args: subjective_state is in shape of (batch_size, feature_size) values is in shape of (batch_size, action_count) Returns: return shape(batch_size, action_count) or Args: subjective_state is in shape of (feature_size) values is in shape of (action_count) Returns: return shape(action_count) """ assert isinstance(action_space, DiscreteActionSpace) action_count = action_space.n values = values.view(-1, action_count) # (batch_size, action_count) sigma = self.sigma( subjective_state=subjective_state, # pyre-fixme[6]: For 2nd argument expected `Module` but got # `Optional[Module]`. representation=representation, ) # a safe check before reshape sigma into values sigma = sigma.view(values.shape) ucb_scores = values + self._alpha * sigma return ucb_scores.view(-1, action_space.n) # batch_size, action_count
Ancestors
Subclasses
Methods
def get_scores(self, subjective_state: torch.Tensor, values: torch.Tensor, action_space: ActionSpace, representation: Optional[torch.nn.modules.module.Module] = None, exploit_action: Optional[torch.Tensor] = None) ‑> torch.Tensor
-
Args
subjective_state is in shape of (batch_size, feature_size) values is in shape of (batch_size, action_count)
Returns
return shape(batch_size, action_count) or
Args
subjective_state is in shape of (feature_size) values is in shape of (action_count)
Returns
return shape(action_count)
Expand source code
def get_scores( self, subjective_state: SubjectiveState, values: torch.Tensor, action_space: ActionSpace, representation: Optional[torch.nn.Module] = None, exploit_action: Optional[Action] = None, ) -> torch.Tensor: """ Args: subjective_state is in shape of (batch_size, feature_size) values is in shape of (batch_size, action_count) Returns: return shape(batch_size, action_count) or Args: subjective_state is in shape of (feature_size) values is in shape of (action_count) Returns: return shape(action_count) """ assert isinstance(action_space, DiscreteActionSpace) action_count = action_space.n values = values.view(-1, action_count) # (batch_size, action_count) sigma = self.sigma( subjective_state=subjective_state, # pyre-fixme[6]: For 2nd argument expected `Module` but got # `Optional[Module]`. representation=representation, ) # a safe check before reshape sigma into values sigma = sigma.view(values.shape) ucb_scores = values + self._alpha * sigma return ucb_scores.view(-1, action_space.n) # batch_size, action_count
def sigma(self, subjective_state: torch.Tensor, representation: torch.nn.modules.module.Module) ‑> torch.Tensor
-
Args
subjective_state
- feature vector (either state,
or state and action features after concatenation) Shape should be either (batch_size, action_count, feature_dim) or (batch_size, feature_dim).
Returns
sigma with shape (batch_size, action_count) or (batch_size, 1)
Expand source code
def sigma( self, subjective_state: SubjectiveState, representation: torch.nn.Module, ) -> torch.Tensor: """ Args: subjective_state: feature vector (either state, or state and action features after concatenation) Shape should be either (batch_size, action_count, feature_dim) or (batch_size, feature_dim). Returns: sigma with shape (batch_size, action_count) or (batch_size, 1) """ sigma = representation.calculate_sigma(subjective_state) nan_check = torch.isnan(sigma) sigma = torch.where(nan_check, torch.zeros_like(sigma), sigma) return sigma
Inherited members
class VanillaUCBExploration
-
Vanilla UCB exploration module with counter.
Expand source code
class VanillaUCBExploration(UCBExploration): """ Vanilla UCB exploration module with counter. """ def __init__(self) -> None: super(VanillaUCBExploration, self).__init__(alpha=1) # pyre-fixme[4]: Attribute must be annotated. self.action_execution_count = {} # pyre-fixme[4]: Attribute must be annotated. self.action_executed = torch.tensor(1) # pyre-fixme[14]: `sigma` overrides method defined in `UCBExploration` # inconsistently. def sigma( self, subjective_state: SubjectiveState, action_space: ActionSpace, representation: Optional[torch.nn.Module] = None, ) -> torch.Tensor: assert isinstance(action_space, DiscreteActionSpace) exploration_bonus = torch.zeros((action_space.n)) # (action_space_size) for action in action_space.actions: if action not in self.action_execution_count: self.action_execution_count[action] = 1 exploration_bonus[action] = torch.sqrt( torch.log(self.action_executed) / self.action_execution_count[action] ) return exploration_bonus # TODO: We should make discrete action space itself iterable # pyre-fixme[14]: `act` overrides method defined in `ScoreExplorationBase` # inconsistently. def act( self, subjective_state: SubjectiveState, action_space: ActionSpace, values: torch.Tensor, # pyre-fixme[2]: Parameter annotation cannot be `Any`. representation: Any = None, exploit_action: Optional[Action] = None, ) -> Action: selected_action = super().act( subjective_state, action_space, values, representation, exploit_action, ) self.action_execution_count[selected_action] += 1 self.action_executed += 1 return selected_action
Ancestors
Inherited members