Module pearl.policy_learners.contextual_bandits.disjoint_bandit
Expand source code
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Dict, List, Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
SubjectiveState,
)
from pearl.neural_networks.common.utils import ensemble_forward
from pearl.policy_learners.contextual_bandits.contextual_bandit_base import (
ContextualBanditBase,
)
from pearl.policy_learners.exploration_modules.common.score_exploration_base import (
ScoreExplorationBase,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
concatenate_actions_to_state,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
class DisjointBanditContainer(ContextualBanditBase):
"""
Wrapper for disjoint models with discrete (and usually small) action space.
Each action has its own bandit model (can be based on UCB or Thompson Sampling).
Using the Composite design pattern:
https://refactoring.guru/design-patterns/composite
"""
def __init__(
self,
feature_dim: int,
arm_bandits: List[ContextualBanditBase],
exploration_module: ExplorationModule,
training_rounds: int = 100,
batch_size: int = 128,
state_features_only: bool = False,
) -> None:
super(DisjointBanditContainer, self).__init__(
feature_dim=feature_dim,
exploration_module=exploration_module,
training_rounds=training_rounds,
batch_size=batch_size,
)
# Currently our disjoint LinUCB usecase only use LinearRegression
self._arm_bandits: torch.nn.ModuleList = torch.nn.ModuleList(arm_bandits)
self._n_arms: int = len(arm_bandits)
self._state_features_only = state_features_only
@property
def n_arms(self) -> int:
return self._n_arms
def _validate_batch(self, batch: TransitionBatch) -> None:
assert (
batch.action.dtype == torch.long
), "action must be torch.long type (index of arm)"
assert batch.action.min().item() >= 0, "action must be >= 0"
assert (
batch.action.max().item() < self._n_arms
), "action must be < number of arms"
def _partition_batch_by_arm(self, batch: TransitionBatch) -> List[TransitionBatch]:
"""
Break input batch down into per-arm batches based on action
"""
batches = []
for arm in range(self.n_arms):
# mask of observations for this arm
# assume action indices
mask = batch.action[:, 0] == arm
if batch.state.ndim == 2:
# shape: (batch_size, feature_size)
# same features for all arms
state = batch.state
elif batch.state.ndim == 3:
# shape: (batch_size, num_arms, feature_size)
# different features for each arm
assert (
batch.state.shape[1] == self.n_arms
), "For 3D state, 2nd dimension must be equal to number of arms"
state = batch.state[:, arm, :]
batches.append(
TransitionBatch(
state=state[mask],
reward=batch.reward[mask],
weight=batch.weight[mask]
if batch.weight is not None
else torch.ones_like(mask, dtype=torch.float),
# empty action features since disjoint model used
# action as index of per-arm model
# if arms need different features, use 3D `state` instead
action=torch.empty(
int(mask.sum().item()),
0,
dtype=torch.float,
device=batch.device,
),
).to(batch.device)
)
return batches
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"""
action_idx determines which of the models the observation will be routed to.
"""
self._validate_batch(batch)
arm_batches = self._partition_batch_by_arm(batch)
returns = {}
for i, (arm_bandit, arm_batch) in enumerate(
zip(self._arm_bandits, arm_batches)
):
if len(arm_batch) == 0:
# skip updates if batch has no observations for this arm
continue
returns.update(
{
f"arm_{i}_{k}": v
for k, v in arm_bandit.learn_batch(arm_batch).items()
}
)
return returns
@property
def models(self) -> List[torch.nn.Module]:
"""
Get a list of models of each bandit
"""
return [bandit.model for bandit in self._arm_bandits]
def act(
self,
subjective_state: SubjectiveState,
available_action_space: ActionSpace,
action_availability_mask: Optional[torch.Tensor] = None,
exploit: bool = False,
) -> Action:
assert isinstance(available_action_space, DiscreteActionSpace)
# (batch_size, action_count, feature_size)
feature = concatenate_actions_to_state(
subjective_state=subjective_state,
action_space=available_action_space,
state_features_only=self._state_features_only,
)
# (batch_size, action_count, feature_size)
values = ensemble_forward(self.models, feature, use_for_loop=True)
return self._exploration_module.act(
subjective_state=feature,
action_space=available_action_space,
values=values,
# pyre-fixme[6]: In call `ExplorationModule.act`, for argument
# `representation`, expected `Optional[Module]` but got `List[Module]`.
representation=self.models,
action_availability_mask=action_availability_mask,
)
def get_scores(
self,
subjective_state: SubjectiveState,
action_space: DiscreteActionSpace,
) -> torch.Tensor:
"""
Returns:
UCB scores when exploration module is UCB
Shape is (batch, num_arms) or (num_arms,)
"""
exploration_module = self._exploration_module
assert isinstance(exploration_module, ScoreExplorationBase)
feature = concatenate_actions_to_state(
subjective_state=subjective_state,
action_space=action_space,
state_features_only=self._state_features_only,
)
# (batch_size, action_count, feature_size)
return exploration_module.get_scores(
subjective_state=feature,
values=ensemble_forward(self.models, feature, use_for_loop=True),
action_space=action_space,
representation=self.models, # pyre-fixme[6]: unexpected type
).squeeze()
@property
def optimizer(self) -> torch.optim.Optimizer:
return self._optimizer
Classes
class DisjointBanditContainer (feature_dim: int, arm_bandits: List[ContextualBanditBase], exploration_module: ExplorationModule, training_rounds: int = 100, batch_size: int = 128, state_features_only: bool = False)
-
Wrapper for disjoint models with discrete (and usually small) action space. Each action has its own bandit model (can be based on UCB or Thompson Sampling). Using the Composite design pattern: https://refactoring.guru/design-patterns/composite
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class DisjointBanditContainer(ContextualBanditBase): """ Wrapper for disjoint models with discrete (and usually small) action space. Each action has its own bandit model (can be based on UCB or Thompson Sampling). Using the Composite design pattern: https://refactoring.guru/design-patterns/composite """ def __init__( self, feature_dim: int, arm_bandits: List[ContextualBanditBase], exploration_module: ExplorationModule, training_rounds: int = 100, batch_size: int = 128, state_features_only: bool = False, ) -> None: super(DisjointBanditContainer, self).__init__( feature_dim=feature_dim, exploration_module=exploration_module, training_rounds=training_rounds, batch_size=batch_size, ) # Currently our disjoint LinUCB usecase only use LinearRegression self._arm_bandits: torch.nn.ModuleList = torch.nn.ModuleList(arm_bandits) self._n_arms: int = len(arm_bandits) self._state_features_only = state_features_only @property def n_arms(self) -> int: return self._n_arms def _validate_batch(self, batch: TransitionBatch) -> None: assert ( batch.action.dtype == torch.long ), "action must be torch.long type (index of arm)" assert batch.action.min().item() >= 0, "action must be >= 0" assert ( batch.action.max().item() < self._n_arms ), "action must be < number of arms" def _partition_batch_by_arm(self, batch: TransitionBatch) -> List[TransitionBatch]: """ Break input batch down into per-arm batches based on action """ batches = [] for arm in range(self.n_arms): # mask of observations for this arm # assume action indices mask = batch.action[:, 0] == arm if batch.state.ndim == 2: # shape: (batch_size, feature_size) # same features for all arms state = batch.state elif batch.state.ndim == 3: # shape: (batch_size, num_arms, feature_size) # different features for each arm assert ( batch.state.shape[1] == self.n_arms ), "For 3D state, 2nd dimension must be equal to number of arms" state = batch.state[:, arm, :] batches.append( TransitionBatch( state=state[mask], reward=batch.reward[mask], weight=batch.weight[mask] if batch.weight is not None else torch.ones_like(mask, dtype=torch.float), # empty action features since disjoint model used # action as index of per-arm model # if arms need different features, use 3D `state` instead action=torch.empty( int(mask.sum().item()), 0, dtype=torch.float, device=batch.device, ), ).to(batch.device) ) return batches def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: """ action_idx determines which of the models the observation will be routed to. """ self._validate_batch(batch) arm_batches = self._partition_batch_by_arm(batch) returns = {} for i, (arm_bandit, arm_batch) in enumerate( zip(self._arm_bandits, arm_batches) ): if len(arm_batch) == 0: # skip updates if batch has no observations for this arm continue returns.update( { f"arm_{i}_{k}": v for k, v in arm_bandit.learn_batch(arm_batch).items() } ) return returns @property def models(self) -> List[torch.nn.Module]: """ Get a list of models of each bandit """ return [bandit.model for bandit in self._arm_bandits] def act( self, subjective_state: SubjectiveState, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False, ) -> Action: assert isinstance(available_action_space, DiscreteActionSpace) # (batch_size, action_count, feature_size) feature = concatenate_actions_to_state( subjective_state=subjective_state, action_space=available_action_space, state_features_only=self._state_features_only, ) # (batch_size, action_count, feature_size) values = ensemble_forward(self.models, feature, use_for_loop=True) return self._exploration_module.act( subjective_state=feature, action_space=available_action_space, values=values, # pyre-fixme[6]: In call `ExplorationModule.act`, for argument # `representation`, expected `Optional[Module]` but got `List[Module]`. representation=self.models, action_availability_mask=action_availability_mask, ) def get_scores( self, subjective_state: SubjectiveState, action_space: DiscreteActionSpace, ) -> torch.Tensor: """ Returns: UCB scores when exploration module is UCB Shape is (batch, num_arms) or (num_arms,) """ exploration_module = self._exploration_module assert isinstance(exploration_module, ScoreExplorationBase) feature = concatenate_actions_to_state( subjective_state=subjective_state, action_space=action_space, state_features_only=self._state_features_only, ) # (batch_size, action_count, feature_size) return exploration_module.get_scores( subjective_state=feature, values=ensemble_forward(self.models, feature, use_for_loop=True), action_space=action_space, representation=self.models, # pyre-fixme[6]: unexpected type ).squeeze() @property def optimizer(self) -> torch.optim.Optimizer: return self._optimizer
Ancestors
- ContextualBanditBase
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Instance variables
var models : List[torch.nn.modules.module.Module]
-
Get a list of models of each bandit
Expand source code
@property def models(self) -> List[torch.nn.Module]: """ Get a list of models of each bandit """ return [bandit.model for bandit in self._arm_bandits]
var n_arms : int
-
Expand source code
@property def n_arms(self) -> int: return self._n_arms
var optimizer : torch.optim.optimizer.Optimizer
-
Expand source code
@property def optimizer(self) -> torch.optim.Optimizer: return self._optimizer
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False) ‑> torch.Tensor
-
Expand source code
def act( self, subjective_state: SubjectiveState, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False, ) -> Action: assert isinstance(available_action_space, DiscreteActionSpace) # (batch_size, action_count, feature_size) feature = concatenate_actions_to_state( subjective_state=subjective_state, action_space=available_action_space, state_features_only=self._state_features_only, ) # (batch_size, action_count, feature_size) values = ensemble_forward(self.models, feature, use_for_loop=True) return self._exploration_module.act( subjective_state=feature, action_space=available_action_space, values=values, # pyre-fixme[6]: In call `ExplorationModule.act`, for argument # `representation`, expected `Optional[Module]` but got `List[Module]`. representation=self.models, action_availability_mask=action_availability_mask, )
def get_scores(self, subjective_state: torch.Tensor, action_space: DiscreteActionSpace) ‑> torch.Tensor
-
Returns
UCB scores when exploration module is UCB Shape is (batch, num_arms) or (num_arms,)
Expand source code
def get_scores( self, subjective_state: SubjectiveState, action_space: DiscreteActionSpace, ) -> torch.Tensor: """ Returns: UCB scores when exploration module is UCB Shape is (batch, num_arms) or (num_arms,) """ exploration_module = self._exploration_module assert isinstance(exploration_module, ScoreExplorationBase) feature = concatenate_actions_to_state( subjective_state=subjective_state, action_space=action_space, state_features_only=self._state_features_only, ) # (batch_size, action_count, feature_size) return exploration_module.get_scores( subjective_state=feature, values=ensemble_forward(self.models, feature, use_for_loop=True), action_space=action_space, representation=self.models, # pyre-fixme[6]: unexpected type ).squeeze()
def learn_batch(self, batch: TransitionBatch) ‑> Dict[str, Any]
-
action_idx determines which of the models the observation will be routed to.
Expand source code
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: """ action_idx determines which of the models the observation will be routed to. """ self._validate_batch(batch) arm_batches = self._partition_batch_by_arm(batch) returns = {} for i, (arm_bandit, arm_batch) in enumerate( zip(self._arm_bandits, arm_batches) ): if len(arm_batch) == 0: # skip updates if batch has no observations for this arm continue returns.update( { f"arm_{i}_{k}": v for k, v in arm_bandit.learn_batch(arm_batch).items() } ) return returns
Inherited members