Module pearl.policy_learners.contextual_bandits.contextual_bandit_base
Expand source code
from abc import abstractmethod
from typing import Any, Dict, Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Value
from pearl.history_summarization_modules.history_summarization_module import (
SubjectiveState,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
# Module-level default: a discrete action space holding exactly one action,
# ``torch.tensor([0])``. It is not referenced by the class below; presumably
# subclasses or callers import it — verify before removing.
DEFAULT_ACTION_SPACE = DiscreteActionSpace(
    [torch.tensor([0])],
)
class ContextualBanditBase(PolicyLearner):
    """
    A base class for contextual bandit policy learners.

    Concrete subclasses must implement ``learn_batch``, ``act`` and
    ``get_scores``. Every learner built on this base is configured as
    off-policy (``on_policy=False``) and, for now, as operating on discrete
    actions only (``is_action_continuous=False``).
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: Feature dimension for this learner. It is stored and
                exposed through the read-only ``feature_dim`` property.
            exploration_module: Exploration module forwarded to
                ``PolicyLearner``.
            training_rounds: Forwarded to ``PolicyLearner``. Defaults to 100.
            batch_size: Forwarded to ``PolicyLearner``. Defaults to 128.
        """
        # Zero-argument super() is equivalent to the explicit
        # super(ContextualBanditBase, self) form and survives class renames.
        super().__init__(
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            on_policy=False,
            # TODO change in subclasses when we add CB for continuous
            is_action_continuous=False,
        )
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        """Feature dimension passed to the constructor (read-only)."""
        return self._feature_dim

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Learn from one batch of transitions. Implemented by subclasses.

        Args:
            batch: The ``TransitionBatch`` to learn from.

        Returns:
            A string-keyed dictionary whose contents are defined by each
            subclass (presumably training statistics — confirm per subclass).
        """

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Return an ``Action`` for ``subjective_state`` chosen from
        ``available_action_space``. Implemented by subclasses.

        Args:
            subjective_state: The subjective state to act on.
            available_action_space: The action space to choose from.
            action_availability_mask: Optional tensor mask over the actions.
                Its interpretation is left to subclasses (presumably it marks
                which actions are currently available — confirm in the
                implementations).
            exploit: Boolean flag, ``False`` by default. Presumably ``True``
                requests selection without exploration — confirm in the
                subclasses.

        Returns:
            The selected ``Action``.
        """

    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Compute scores for ``subjective_state`` with this contextual bandit
        algorithm. Implemented by subclasses.

        Args:
            subjective_state: The subjective state to score.

        Returns:
            The scores produced by this contextual bandit algorithm. Their
            exact type and shape are defined by each subclass.
        """
Classes
class ContextualBanditBase (feature_dim: int, exploration_module: ExplorationModule, training_rounds: int = 100, batch_size: int = 128)
-
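A base class for contextual bandit policy learners.
The constructor does two things. It initializes the underlying PolicyLearner (and hence torch.nn.Module) state with on_policy=False and is_action_continuous=False. It also stores feature_dim.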
Expand source code
class ContextualBanditBase(PolicyLearner):
    """
    A base class for contextual bandit policy learners.

    Concrete subclasses must implement ``learn_batch``, ``act`` and
    ``get_scores``. Every learner built on this base is configured as
    off-policy (``on_policy=False``) and, for now, as operating on discrete
    actions only (``is_action_continuous=False``).
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: Feature dimension for this learner. It is stored and
                exposed through the read-only ``feature_dim`` property.
            exploration_module: Exploration module forwarded to
                ``PolicyLearner``.
            training_rounds: Forwarded to ``PolicyLearner``. Defaults to 100.
            batch_size: Forwarded to ``PolicyLearner``. Defaults to 128.
        """
        # Zero-argument super() is equivalent to the explicit
        # super(ContextualBanditBase, self) form and survives class renames.
        super().__init__(
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            on_policy=False,
            # TODO change in subclasses when we add CB for continuous
            is_action_continuous=False,
        )
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        """Feature dimension passed to the constructor (read-only)."""
        return self._feature_dim

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Learn from one batch of transitions. Implemented by subclasses.

        Args:
            batch: The ``TransitionBatch`` to learn from.

        Returns:
            A string-keyed dictionary whose contents are defined by each
            subclass (presumably training statistics — confirm per subclass).
        """

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Return an ``Action`` for ``subjective_state`` chosen from
        ``available_action_space``. Implemented by subclasses.

        Args:
            subjective_state: The subjective state to act on.
            available_action_space: The action space to choose from.
            action_availability_mask: Optional tensor mask over the actions.
                Its interpretation is left to subclasses (presumably it marks
                which actions are currently available — confirm in the
                implementations).
            exploit: Boolean flag, ``False`` by default. Presumably ``True``
                requests selection without exploration — confirm in the
                subclasses.

        Returns:
            The selected ``Action``.
        """

    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Compute scores for ``subjective_state`` with this contextual bandit
        algorithm. Implemented by subclasses.

        Args:
            subjective_state: The subjective state to score.

        Returns:
            The scores produced by this contextual bandit algorithm. Their
            exact type and shape are defined by each subclass.
        """
Ancestors
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Instance variables
var feature_dim : int
- Feature dimension passed to the constructor; a read-only property.
Expand source code
@property
def feature_dim(self) -> int:
    """Read-only accessor for the feature dimension given at construction."""
    dim: int = self._feature_dim
    return dim
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False) ‑> torch.Tensor
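- Abstract method. Subclasses return an Action for subjective_state, chosen from available_action_space.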
Expand source code
@abstractmethod
def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: ActionSpace,
    action_availability_mask: Optional[torch.Tensor] = None,
    exploit: bool = False,
) -> Action:
    """
    Abstract hook: subclasses return an ``Action`` for ``subjective_state``
    chosen from ``available_action_space``.

    ``action_availability_mask`` is an optional tensor whose meaning is
    defined by subclasses. ``exploit`` is a boolean flag that defaults to
    ``False``; presumably it requests selection without exploration —
    confirm in the implementations.
    """
def get_scores(self, subjective_state: torch.Tensor) ‑> object
-
Returns
The scores computed by this contextual bandit algorithm for subjective_state (abstract; implemented by subclasses).
Expand source code
@abstractmethod
def get_scores(
    self,
    subjective_state: SubjectiveState,
) -> Value:
    """
    Abstract hook: subclasses compute scores for ``subjective_state``.

    Returns:
        The scores produced by this contextual bandit algorithm; their exact
        type and shape are defined by each subclass.
    """
Inherited members