Module pearl.policy_learners.contextual_bandits.contextual_bandit_base
Expand source code
from abc import abstractmethod
from typing import Any, Dict, Optional
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Value
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
# Module-level default action space: a DiscreteActionSpace built from a list
# holding one action, torch.tensor([0]). It is not referenced by the class
# below; presumably importers use it as a fallback when no real action space
# applies -- TODO confirm against callers.
DEFAULT_ACTION_SPACE = DiscreteActionSpace([torch.tensor([0])])
class ContextualBanditBase(PolicyLearner):
    """
    Abstract base for contextual bandit policy learners.

    The constructor records ``feature_dim`` and forwards the remaining
    configuration to ``PolicyLearner``, always with ``on_policy=False`` and
    ``is_action_continuous=False``. Concrete subclasses have to implement
    ``learn_batch``, ``act`` and ``get_scores``.
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: value kept on the instance and exposed through the
                read-only ``feature_dim`` property (presumably the size of
                the feature vector the bandit consumes -- confirm in
                subclasses).
            exploration_module: passed through unchanged to ``PolicyLearner``.
            training_rounds: passed through unchanged to ``PolicyLearner``.
            batch_size: passed through unchanged to ``PolicyLearner``.
        """
        super().__init__(
            exploration_module=exploration_module,
            training_rounds=training_rounds,
            batch_size=batch_size,
            on_policy=False,
            is_action_continuous=False,  # TODO change in subclasses when we add CB for continuous
        )
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        """The ``feature_dim`` value supplied at construction time."""
        return self._feature_dim

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Abstract hook: subclasses implement learning from one ``batch``.

        Returns:
            A dictionary with string keys; its contents are defined by the
            implementing subclass.
        """

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Abstract hook: subclasses choose an action for ``subjective_state``.

        Args:
            subjective_state: state the action is chosen for.
            available_action_space: action space to choose from.
            action_availability_mask: optional tensor, ``None`` by default;
                how it is interpreted is left to the subclass.
            exploit: flag, ``False`` by default (presumably disables
                exploration when ``True`` -- confirm in subclasses).
        """

    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Returns:
            Return scores trained by this contextual bandit algorithm
        """
Classes
class ContextualBanditBase (feature_dim: int, exploration_module: ExplorationModule, training_rounds: int = 100, batch_size: int = 128)- 
A base class for Contextual Bandit policy learner.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class ContextualBanditBase(PolicyLearner):
    """
    A base class for Contextual Bandit policy learner.
    """
    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        super(ContextualBanditBase, self).__init__(
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            on_policy=False,
            is_action_continuous=False,  # TODO change in subclasses when we add CB for continuous
        )
        self._feature_dim = feature_dim
    @property
    def feature_dim(self) -> int:
        return self._feature_dim
    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        pass
    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        pass
    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Returns:
            Return scores trained by this contextual bandit algorithm
        """
        pass
Ancestors
- PolicyLearner
 - torch.nn.modules.module.Module
 - abc.ABC
 
Subclasses
Instance variables
var feature_dim : int- 
Expand source code
@property
def feature_dim(self) -> int:
    return self._feature_dim
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False) ‑> torch.Tensor- 
Expand source code
@abstractmethod
def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: ActionSpace,
    action_availability_mask: Optional[torch.Tensor] = None,
    exploit: bool = False,
) -> Action:
    pass
def get_scores(self, subjective_state: torch.Tensor) ‑> object- 
Returns
Return scores trained by this contextual bandit algorithm
Expand source code
@abstractmethod
def get_scores(
    self,
    subjective_state: SubjectiveState,
) -> Value:
    """
    Returns:
        Return scores trained by this contextual bandit algorithm
    """
    pass
Inherited members