Module pearl.policy_learners.sequential_decision_making.tabular_q_learning

Source code
import random
from typing import Any, Dict, Iterable, List, Tuple

import torch

from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Reward, Value
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

# TODO: make package names and organization more consistent
# TODO: This class currently assumes action index, not generic DiscreteActionSpace.
#   Need to fix this.


class TabularQLearning(PolicyLearner):
    """
    A tabular Q-learning policy learner.
    """

    def __init__(
        self,
        learning_rate: float = 0.01,
        discount_factor: float = 0.9,
        exploration_rate: float = 0.01,
        debug: bool = False,
    ) -> None:
        """
        Initializes the tabular Q-learning policy learner.

        Args:
            learning_rate (float, optional): the learning rate. Defaults to 0.01.
            discount_factor (float, optional): the discount factor. Defaults to 0.9.
            exploration_rate (float, optional): the exploration rate. Defaults to 0.01.
            debug (bool, optional): whether to print debug information to standard
                output. Defaults to False.
        """
        super(TabularQLearning, self).__init__(
            exploration_module=EGreedyExploration(exploration_rate),
            on_policy=False,
            is_action_continuous=False,
            requires_tensors=False,  # temporary solution before abstract interfaces
        )
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
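        # Q-value table: maps (state, action index) pairs to values;
        # unseen pairs default to 0 via dict.get lookups below.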
        self.q_values: Dict[Tuple[SubjectiveState, int], Value] = {}
        self.debug: bool = debug

    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space

    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        exploit: bool = False,
    ) -> Action:
        assert isinstance(available_action_space, DiscreteSpace)
        # FIXME: this conversion should be eliminated once Action
        # is no longer constrained to be a Tensor.
        actions_as_ints: List[int] = [int(a.item()) for a in available_action_space]
        # Choose the action with the highest Q-value for the current state.
        q_values_for_state = {
            action: self.q_values.get((subjective_state, action), 0)
            for action in actions_as_ints
        }
        # pyre-fixme[6]: For 1st argument expected
        #  `Iterable[Variable[SupportsRichComparisonT (bound to
        #  Union[SupportsDunderGT[typing.Any], SupportsDunderLT[typing.Any]])]]` but
        #  got `dict_values[int, Number]`.
        # Fixing this will require that Value is defined so it supports
        # rich comparisons.
        max_q_value = max(q_values_for_state.values())
        best_actions = [
            action
            for action, q_value in q_values_for_state.items()
            if q_value == max_q_value
        ]
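        # Break ties uniformly at random among the highest-valued actions.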
        exploit_action = random.choice(best_actions)
        exploit_action = torch.tensor([exploit_action])
        if exploit:
            return exploit_action

        return self._exploration_module.act(
            subjective_state,
            available_action_space,
            exploit_action,
        )

    def learn(
        self,
        replay_buffer: ReplayBuffer,
    ) -> Dict[str, Any]:
        # We know the sampling result from SingleTransitionReplayBuffer
        # is a list with a single tuple.
        transitions = replay_buffer.sample(1)
        assert isinstance(transitions, Iterable)
        transition = next(iter(transitions))
        assert isinstance(transition, Iterable)
        # We currently assume replay buffer only contains last transition (on-policy)
        (
            state,
            action,
            reward,
            next_state,
            _curr_available_actions,
            _next_available_actions,
            done,
            _max_number_actions,
            _cost,
        ) = transition
        old_q_value = self.q_values.get((state, action.item()), 0)
        next_q_values = [
            self.q_values.get((next_state, next_action.item()), 0)
            for next_action in self._action_space
        ]

        if done:
            next_state_value = 0
        else:
            # pyre-fixme[6]: For 1st argument expected
            #  `Iterable[Variable[SupportsRichComparisonT (bound to
            #  Union[SupportsDunderGT[typing.Any],
            #  SupportsDunderLT[typing.Any]])]]` but got `List[Number]`.
            max_next_q_value = max(next_q_values) if next_q_values else 0
            next_state_value = self.discount_factor * max_next_q_value

        # pyre-fixme[58]: `+` is not supported for operand types `Number` and
        #  `float`.
        # FIXME: not finding a generic assertion that would fix this.
        # assert isinstance(old_q_value, Union[torch.Tensor, int, float])
        # does not work. Pending discussion.
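        # Tabular Q-learning update:
        # Q(s, a) <- Q(s, a) + learning_rate * (reward + next_state_value - Q(s, a)),
        # where next_state_value is discount_factor * max_a' Q(s', a'), or 0 if done.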
        new_q_value = old_q_value + self.learning_rate * (
            reward + next_state_value - old_q_value
        )

        self.q_values[(state, action.item())] = new_q_value

        if self.debug:
            self.print_debug_information(state, action, reward, next_state, done)

        return {
            "state": state,
            "action": action,
            "reward": reward,
            "next_state": next_state,
            "done": done,
        }

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        raise Exception("tabular_q_learning doesnt need learn_batch")

    def print_debug_information(
        self,
        state: SubjectiveState,
        action: Action,
        reward: Reward,
        next_state: SubjectiveState,
        done: bool,
    ) -> None:
        print("state:", state)
        print("action:", action)
        print("reward:", reward)
        print("next state:", next_state)
        print("done:", done)
        print("q-values:", self.q_values)

    def __str__(self) -> str:
        exploration_module = self._exploration_module
        assert isinstance(exploration_module, EGreedyExploration)
        items = [
            "α=" + str(self.learning_rate),
            "ε=" + str(exploration_module.epsilon),
            "λ=" + str(self.discount_factor),
        ]
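        # Renders e.g. "Q-Learning (α=0.01, ε=0.01, λ=0.9)" with default settings.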
        return "Q-Learning" + (
            " (" + ", ".join(str(item) for item in items) + ")" if items else ""
        )

Classes

class TabularQLearning (learning_rate: float = 0.01, discount_factor: float = 0.9, exploration_rate: float = 0.01, debug: bool = False)

A tabular Q-learning policy learner.

Initializes the tabular Q-learning policy learner.

Args

learning_rate : float, optional
the learning rate. Defaults to 0.01.
discount_factor : float, optional
the discount factor. Defaults to 0.9.
exploration_rate : float, optional
the exploration rate. Defaults to 0.01.
debug : bool, optional
whether to print debug information to standard output. Defaults to False.
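
Example

A minimal usage sketch, not part of the Pearl source. It assumes that a DiscreteSpace can be constructed from a list of single-element integer tensors and that subjective states are hashable, since they are used as dictionary keys; check the DiscreteSpace constructor in your Pearl version before relying on this.

import torch

from pearl.policy_learners.sequential_decision_making.tabular_q_learning import (
    TabularQLearning,
)
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

# Two hypothetical actions encoded as single-element tensors; the
# DiscreteSpace constructor signature here is an assumption.
action_space = DiscreteSpace([torch.tensor([0]), torch.tensor([1])])

learner = TabularQLearning(
    learning_rate=0.1,
    discount_factor=0.95,
    exploration_rate=0.05,
)
learner.reset(action_space)

# Greedy action for a hashable subjective state; with an empty Q-table all
# actions tie at 0 and one is chosen at random.
state = 0
action = learner.act(state, action_space, exploit=True)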


Ancestors

pearl.policy_learners.policy_learner.PolicyLearner

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
def print_debug_information(self, state: torch.Tensor, action: torch.Tensor, reward: object, next_state: torch.Tensor, done: bool) ‑> None

Inherited members