Module pearl.policy_learners.contextual_bandits.contextual_bandit_base

Expand source code
from abc import abstractmethod
from typing import Any, Dict, Optional

import torch

from pearl.api.action import Action

from pearl.api.action_space import ActionSpace
from pearl.api.reward import Value
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Default action space: a DiscreteActionSpace holding exactly one action,
# the int64 tensor of shape (1,) with value 0 (same as torch.tensor([0])).
DEFAULT_ACTION_SPACE = DiscreteActionSpace([torch.zeros(1, dtype=torch.long)])


class ContextualBanditBase(PolicyLearner):
    """
    Abstract base class shared by contextual bandit policy learners.

    Configures the parent ``PolicyLearner`` as off-policy with discrete
    actions and records the feature dimension. Concrete subclasses provide
    ``learn_batch``, ``act`` and ``get_scores``.
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: feature dimension for this learner; stored and
                exposed through the ``feature_dim`` property.
            exploration_module: exploration module handed to ``PolicyLearner``.
            training_rounds: forwarded unchanged to ``PolicyLearner``.
            batch_size: forwarded unchanged to ``PolicyLearner``.
        """
        super().__init__(
            on_policy=False,
            is_action_continuous=False,  # TODO change in subclasses when we add CB for continuous
            exploration_module=exploration_module,
            batch_size=batch_size,
            training_rounds=training_rounds,
        )
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        """Feature dimension supplied to the constructor."""
        return self._feature_dim

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Learn from a single ``TransitionBatch``.

        Returns:
            A string-keyed dictionary whose contents are chosen by the subclass.
        """

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Return an ``Action`` for ``subjective_state`` drawn from
        ``available_action_space``.

        Abstract: concrete contextual bandits define the selection rule and
        how ``action_availability_mask`` and ``exploit`` influence it.
        """

    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Compute scores for ``subjective_state``.

        Returns:
            The scores learned by this contextual bandit algorithm; their
            concrete type and layout are defined by each subclass.
        """

Classes

class ContextualBanditBase (feature_dim: int, exploration_module: ExplorationModule, training_rounds: int = 100, batch_size: int = 128)

A base class for contextual bandit policy learners.

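Initializes internal Module state, shared by both nn.Module and ScriptModule. (This constructor description is inherited from `torch.nn.Module.__init__`; it does not describe this class specifically.)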

Expand source code
class ContextualBanditBase(PolicyLearner):
    """
    Abstract base class shared by contextual bandit policy learners.

    Configures the parent ``PolicyLearner`` as off-policy with discrete
    actions and records the feature dimension. Concrete subclasses provide
    ``learn_batch``, ``act`` and ``get_scores``.
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: ExplorationModule,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: feature dimension for this learner; stored and
                exposed through the ``feature_dim`` property.
            exploration_module: exploration module handed to ``PolicyLearner``.
            training_rounds: forwarded unchanged to ``PolicyLearner``.
            batch_size: forwarded unchanged to ``PolicyLearner``.
        """
        super().__init__(
            on_policy=False,
            is_action_continuous=False,  # TODO change in subclasses when we add CB for continuous
            exploration_module=exploration_module,
            batch_size=batch_size,
            training_rounds=training_rounds,
        )
        self._feature_dim = feature_dim

    @property
    def feature_dim(self) -> int:
        """Feature dimension supplied to the constructor."""
        return self._feature_dim

    @abstractmethod
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Learn from a single ``TransitionBatch``.

        Returns:
            A string-keyed dictionary whose contents are chosen by the subclass.
        """

    @abstractmethod
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Return an ``Action`` for ``subjective_state`` drawn from
        ``available_action_space``.

        Abstract: concrete contextual bandits define the selection rule and
        how ``action_availability_mask`` and ``exploit`` influence it.
        """

    @abstractmethod
    def get_scores(
        self,
        subjective_state: SubjectiveState,
    ) -> Value:
        """
        Compute scores for ``subjective_state``.

        Returns:
            The scores learned by this contextual bandit algorithm; their
            concrete type and layout are defined by each subclass.
        """

Ancestors

Subclasses

Instance variables

var feature_dim : int
Expand source code
@property
def feature_dim(self) -> int:
    """Feature dimension supplied to the constructor (read-only)."""
    dim = self._feature_dim
    return dim

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False) -> torch.Tensor
Expand source code
@abstractmethod
def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: ActionSpace,
    action_availability_mask: Optional[torch.Tensor] = None,
    exploit: bool = False,
) -> Action:
    """
    Return an ``Action`` for ``subjective_state`` drawn from
    ``available_action_space``.

    Abstract: concrete contextual bandits define the selection rule and
    how ``action_availability_mask`` and ``exploit`` influence it.
    """
def get_scores(self, subjective_state: torch.Tensor) -> object

Returns

The scores learned by this contextual bandit algorithm.

Expand source code
@abstractmethod
def get_scores(
    self,
    subjective_state: SubjectiveState,
) -> Value:
    """
    Compute scores for ``subjective_state``.

    Returns:
        The scores learned by this contextual bandit algorithm; their
        concrete type and layout are defined by each subclass.
    """

Inherited members