Module pearl.policy_learners.contextual_bandits.linear_bandit

Expand source code
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, Dict, Optional

import torch
from pearl.api.action import Action
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.contextual_bandits.contextual_bandit_base import (
    ContextualBanditBase,
    DEFAULT_ACTION_SPACE,
)
from pearl.policy_learners.exploration_modules.common.score_exploration_base import (
    ScoreExplorationBase,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.learning.action_utils import (
    concatenate_actions_to_state,
)
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class LinearBandit(ContextualBanditBase):
    """
    Policy Learner for Contextual Bandit with Linear Policy

    The reward model is a ``LinearRegression`` whose input is the
    concatenation of state features and action features. Action selection
    (``act``) and scoring (``get_scores``) are delegated to the configured
    exploration module, which receives the model predictions and the model
    itself as ``representation``.
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: Optional[ExplorationModule] = None,
        l2_reg_lambda: float = 1.0,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: input dimension of the linear model, i.e. the size of
                the concatenated (state, action) feature vector.
            exploration_module: turns model predictions into actions/scores.
                ``act`` requires it to be set; ``get_scores`` requires it to be
                a ``ScoreExplorationBase``.
            l2_reg_lambda: L2 regularization strength of the ``LinearRegression``.
            training_rounds: forwarded to ``ContextualBanditBase``.
            batch_size: forwarded to ``ContextualBanditBase``.
        """
        super().__init__(
            feature_dim=feature_dim,
            training_rounds=training_rounds,
            batch_size=batch_size,
            # pyre-fixme[6]: For 4th argument expected `ExplorationModule` but got
            #  `Optional[ExplorationModule]`.
            exploration_module=exploration_module,
        )
        self.model = LinearRegression(
            feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda
        )

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Update the linear model with one batch of transitions:
            A <- A + x*x.t
            b <- b + r*x
        where x is ``batch.state`` concatenated with ``batch.action`` along
        dim 1 and r is ``batch.reward``. The update itself is performed by
        ``self.model.learn_batch``, which also receives ``batch.weight``.

        Returns:
            {"current_values": mean of the model's predictions on x, evaluated
            after the update}.

        Raises:
            AssertionError: if ``batch.weight`` is None.
        """
        x = torch.cat([batch.state, batch.action], dim=1)
        assert batch.weight is not None
        self.model.learn_batch(
            x=x,
            y=batch.reward,
            weight=batch.weight,
        )
        current_values = self.model(x)
        return {"current_values": current_values.mean().item()}

    # pyre-fixme[14]: `act` overrides method defined in `ContextualBanditBase`
    #  inconsistently.
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: DiscreteActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Args:
            subjective_state: state will be applied to different action vectors in action_space
            available_action_space: contains a list of action vectors.
                                    Currently, only static spaces are supported.
            action_availability_mask: forwarded unchanged to the exploration module.
            exploit: NOTE(review): not read by this method; action selection is
                always delegated to the exploration module.
        Return:
            action index chosen given state and action vectors, as returned by
            the exploration module's ``act``.
        Raises:
            AssertionError: if no exploration module is set, or if the model
            output is not shaped (batch_size, action_count).
        """
        # It doesnt make sense to call act if we are not working with action vector
        assert (
            self._exploration_module is not None
        ), "exploration module must be set to call act()"
        action_count = available_action_space.n
        new_feature = concatenate_actions_to_state(
            subjective_state=subjective_state, action_space=available_action_space
        )
        values = self.model(new_feature)  # (batch_size, action_count)
        assert values.shape == (new_feature.shape[0], action_count)
        return self._exploration_module.act(
            subjective_state=new_feature,
            action_space=available_action_space,
            values=values,
            action_availability_mask=action_availability_mask,
            representation=self.model,
        )

    def get_scores(
        self,
        subjective_state: SubjectiveState,
        action_space: DiscreteActionSpace = DEFAULT_ACTION_SPACE,
    ) -> torch.Tensor:
        """
        Score every action in ``action_space`` for the given state.

        Returns:
            The exploration module's ``get_scores`` output (e.g. UCB scores when
            the exploration module is UCB) with all size-1 dimensions removed
            by ``squeeze()``. Documented shape is (batch).
        Raises:
            AssertionError: if the exploration module is not a
            ``ScoreExplorationBase``.
        """
        # Bind to a local so the isinstance narrowing survives the intervening
        # call for static type checkers; a single check is then sufficient.
        exploration_module = self._exploration_module
        assert isinstance(exploration_module, ScoreExplorationBase)
        feature = concatenate_actions_to_state(
            subjective_state=subjective_state, action_space=action_space
        )
        return exploration_module.get_scores(
            subjective_state=feature,
            values=self.model(feature),
            action_space=action_space,
            representation=self.model,
        ).squeeze()

Classes

class LinearBandit (feature_dim: int, exploration_module: Optional[ExplorationModule] = None, l2_reg_lambda: float = 1.0, training_rounds: int = 100, batch_size: int = 128)

Policy Learner for Contextual Bandit with Linear Policy

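Initializes internal Module state, shared by both nn.Module and ScriptModule. (This is the generic docstring inherited from the base module; the constructor also creates a `LinearRegression` reward model with `feature_dim` inputs and L2 regularization strength `l2_reg_lambda`.)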

Expand source code
class LinearBandit(ContextualBanditBase):
    """
    Policy Learner for Contextual Bandit with Linear Policy

    The reward model is a ``LinearRegression`` whose input is the
    concatenation of state features and action features. Action selection
    (``act``) and scoring (``get_scores``) are delegated to the configured
    exploration module, which receives the model predictions and the model
    itself as ``representation``.
    """

    def __init__(
        self,
        feature_dim: int,
        exploration_module: Optional[ExplorationModule] = None,
        l2_reg_lambda: float = 1.0,
        training_rounds: int = 100,
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            feature_dim: input dimension of the linear model, i.e. the size of
                the concatenated (state, action) feature vector.
            exploration_module: turns model predictions into actions/scores.
                ``act`` requires it to be set; ``get_scores`` requires it to be
                a ``ScoreExplorationBase``.
            l2_reg_lambda: L2 regularization strength of the ``LinearRegression``.
            training_rounds: forwarded to ``ContextualBanditBase``.
            batch_size: forwarded to ``ContextualBanditBase``.
        """
        super().__init__(
            feature_dim=feature_dim,
            training_rounds=training_rounds,
            batch_size=batch_size,
            # pyre-fixme[6]: For 4th argument expected `ExplorationModule` but got
            #  `Optional[ExplorationModule]`.
            exploration_module=exploration_module,
        )
        self.model = LinearRegression(
            feature_dim=feature_dim, l2_reg_lambda=l2_reg_lambda
        )

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Update the linear model with one batch of transitions:
            A <- A + x*x.t
            b <- b + r*x
        where x is ``batch.state`` concatenated with ``batch.action`` along
        dim 1 and r is ``batch.reward``. The update itself is performed by
        ``self.model.learn_batch``, which also receives ``batch.weight``.

        Returns:
            {"current_values": mean of the model's predictions on x, evaluated
            after the update}.

        Raises:
            AssertionError: if ``batch.weight`` is None.
        """
        x = torch.cat([batch.state, batch.action], dim=1)
        assert batch.weight is not None
        self.model.learn_batch(
            x=x,
            y=batch.reward,
            weight=batch.weight,
        )
        current_values = self.model(x)
        return {"current_values": current_values.mean().item()}

    # pyre-fixme[14]: `act` overrides method defined in `ContextualBanditBase`
    #  inconsistently.
    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: DiscreteActionSpace,
        action_availability_mask: Optional[torch.Tensor] = None,
        exploit: bool = False,
    ) -> Action:
        """
        Args:
            subjective_state: state will be applied to different action vectors in action_space
            available_action_space: contains a list of action vectors.
                                    Currently, only static spaces are supported.
            action_availability_mask: forwarded unchanged to the exploration module.
            exploit: NOTE(review): not read by this method; action selection is
                always delegated to the exploration module.
        Return:
            action index chosen given state and action vectors, as returned by
            the exploration module's ``act``.
        Raises:
            AssertionError: if no exploration module is set, or if the model
            output is not shaped (batch_size, action_count).
        """
        # It doesnt make sense to call act if we are not working with action vector
        assert (
            self._exploration_module is not None
        ), "exploration module must be set to call act()"
        action_count = available_action_space.n
        new_feature = concatenate_actions_to_state(
            subjective_state=subjective_state, action_space=available_action_space
        )
        values = self.model(new_feature)  # (batch_size, action_count)
        assert values.shape == (new_feature.shape[0], action_count)
        return self._exploration_module.act(
            subjective_state=new_feature,
            action_space=available_action_space,
            values=values,
            action_availability_mask=action_availability_mask,
            representation=self.model,
        )

    def get_scores(
        self,
        subjective_state: SubjectiveState,
        action_space: DiscreteActionSpace = DEFAULT_ACTION_SPACE,
    ) -> torch.Tensor:
        """
        Score every action in ``action_space`` for the given state.

        Returns:
            The exploration module's ``get_scores`` output (e.g. UCB scores when
            the exploration module is UCB) with all size-1 dimensions removed
            by ``squeeze()``. Documented shape is (batch).
        Raises:
            AssertionError: if the exploration module is not a
            ``ScoreExplorationBase``.
        """
        # Bind to a local so the isinstance narrowing survives the intervening
        # call for static type checkers; a single check is then sufficient.
        exploration_module = self._exploration_module
        assert isinstance(exploration_module, ScoreExplorationBase)
        feature = concatenate_actions_to_state(
            subjective_state=subjective_state, action_space=action_space
        )
        return exploration_module.get_scores(
            subjective_state=feature,
            values=self.model(feature),
            action_space=action_space,
            representation=self.model,
        ).squeeze()

Ancestors

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: DiscreteActionSpace, action_availability_mask: Optional[torch.Tensor] = None, exploit: bool = False) -> torch.Tensor

Args

subjective_state
state will be applied to different action vectors in action_space
available_action_space
contains a list of action vectors. Currently, only static spaces are supported.

Return

action index chosen given state and action vectors

Expand source code
def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: DiscreteActionSpace,
    action_availability_mask: Optional[torch.Tensor] = None,
    exploit: bool = False,
) -> Action:
    """
    Pick an action by scoring every (state, action-vector) pair with the
    linear model and handing the scores to the exploration module.

    Args:
        subjective_state: state will be applied to different action vectors in action_space
        available_action_space: contains a list of action vectors.
                                Currently, only static spaces are supported.
        action_availability_mask: forwarded unchanged to the exploration module.
        exploit: NOTE(review): not read here; selection is always delegated
            to the exploration module.
    Return:
        action index chosen given state and action vectors, as returned by
        the exploration module's ``act``.
    Raises:
        AssertionError: if no exploration module is set, or if the model
        output is not shaped (batch_size, action_count).
    """
    # Only meaningful when actions are represented as feature vectors.
    explorer = self._exploration_module
    assert explorer is not None, "exploration module must be set to call act()"

    num_actions = available_action_space.n
    state_action_features = concatenate_actions_to_state(
        subjective_state=subjective_state, action_space=available_action_space
    )

    # One prediction per (batch row, action) pair.
    predicted = self.model(state_action_features)
    expected_shape = (state_action_features.shape[0], num_actions)
    assert predicted.shape == expected_shape

    return explorer.act(
        subjective_state=state_action_features,
        action_space=available_action_space,
        values=predicted,
        action_availability_mask=action_availability_mask,
        representation=self.model,
    )
def get_scores(self, subjective_state: torch.Tensor, action_space: DiscreteActionSpace = DEFAULT_ACTION_SPACE) -> torch.Tensor

Returns

UCB scores when the exploration module is UCB. Shape is (batch).

Expand source code
def get_scores(
    self,
    subjective_state: SubjectiveState,
    action_space: DiscreteActionSpace = DEFAULT_ACTION_SPACE,
) -> torch.Tensor:
    """
    Score every action in ``action_space`` for the given state.

    Returns:
        The exploration module's ``get_scores`` output (e.g. UCB scores when
        the exploration module is UCB) with all size-1 dimensions removed by
        ``squeeze()``. Documented shape is (batch).
    Raises:
        AssertionError: if the exploration module is not a
        ``ScoreExplorationBase``.
    """
    # Bind to a local so the isinstance narrowing survives the intervening
    # call for static type checkers; a single check is then sufficient.
    exploration_module = self._exploration_module
    assert isinstance(exploration_module, ScoreExplorationBase)
    feature = concatenate_actions_to_state(
        subjective_state=subjective_state, action_space=action_space
    )
    return exploration_module.get_scores(
        subjective_state=feature,
        values=self.model(feature),
        action_space=action_space,
        representation=self.model,
    ).squeeze()
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]

$A \leftarrow A + x x^\top$ and $b \leftarrow b + r\,x$

Expand source code
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
    """
    Fit the linear model on one batch using the accumulated statistics
        A <- A + x*x.t
        b <- b + r*x
    with x = [state, action] (concatenated on dim 1) and r = reward.

    Returns:
        {"current_values": mean model prediction on x after the update}.
    Raises:
        AssertionError: if the batch carries no weights.
    """
    features = torch.cat((batch.state, batch.action), dim=1)
    sample_weight = batch.weight
    assert sample_weight is not None

    self.model.learn_batch(x=features, y=batch.reward, weight=sample_weight)

    # Report the post-update mean prediction over this batch.
    mean_prediction = self.model(features).mean().item()
    return {"current_values": mean_prediction}

Inherited members