Module pearl.policy_learners.sequential_decision_making.tabular_q_learning
Expand source code
import random
from typing import Any, Dict, Iterable, List, Tuple
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Reward, Value
from pearl.history_summarization_modules.history_summarization_module import (
SubjectiveState,
)
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
EGreedyExploration,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace
# TODO: make package names and organization more consistent
# TODO: This class currently assumes action index, not generic DiscreteActionSpace.
# Need to fix this.
class TabularQLearning(PolicyLearner):
"""
A tabular Q-learning policy learner.
"""
def __init__(
self,
learning_rate: float = 0.01,
discount_factor: float = 0.9,
exploration_rate: float = 0.01,
debug: bool = False,
) -> None:
"""
Initializes the tabular Q-learning policy learner.
Args:
learning_rate (float, optional): the learning rate. Defaults to 0.01.
discount_factor (float, optional): the discount factor. Defaults to 0.9.
exploration_rate (float, optional): the exploration rate. Defaults to 0.01.
debug (bool, optional): whether to print debug information to standard output.
Defaults to False.
"""
super(TabularQLearning, self).__init__(
exploration_module=EGreedyExploration(exploration_rate),
on_policy=False,
is_action_continuous=False,
requires_tensors=False, # temporary solution before abstract interfaces
)
self.learning_rate = learning_rate
self.discount_factor = discount_factor
self.q_values: Dict[Tuple[SubjectiveState, int], Value] = {}
self.debug: bool = debug
def reset(self, action_space: ActionSpace) -> None:
self._action_space = action_space
def act(
self,
subjective_state: SubjectiveState,
available_action_space: ActionSpace,
exploit: bool = False,
) -> Action:
assert isinstance(available_action_space, DiscreteSpace)
# FIXME: this conversion should be eliminated once Action
# is no longer constrained to be a Tensor.
actions_as_ints: List[int] = [int(a.item()) for a in available_action_space]
# Choose the action with the highest Q-value for the current state.
q_values_for_state = {
action: self.q_values.get((subjective_state, action), 0)
for action in actions_as_ints
}
# pyre-fixme[6]: For 1st argument expected
# `Iterable[Variable[SupportsRichComparisonT (bound to
# Union[SupportsDunderGT[typing.Any], SupportsDunderLT[typing.Any]])]]` but
# got `dict_values[int, Number]`.
# Fixing this will require that Value is defined so it supports
# rich comparisons.
max_q_value = max(q_values_for_state.values())
best_actions = [
action
for action, q_value in q_values_for_state.items()
if q_value == max_q_value
]
exploit_action = random.choice(best_actions)
exploit_action = torch.tensor([exploit_action])
if exploit:
return exploit_action
return self._exploration_module.act(
subjective_state,
available_action_space,
exploit_action,
)
def learn(
self,
replay_buffer: ReplayBuffer,
) -> Dict[str, Any]:
# We know the sampling result from SingleTransitionReplayBuffer
# is a list with a single tuple.
transitions = replay_buffer.sample(1)
assert isinstance(transitions, Iterable)
transition = next(iter(transitions))
assert isinstance(transition, Iterable)
# We currently assume replay buffer only contains last transition (on-policy)
(
state,
action,
reward,
next_state,
_curr_available_actions,
_next_available_actions,
done,
_max_number_actions,
_cost,
) = transition
old_q_value = self.q_values.get((state, action.item()), 0)
next_q_values = [
self.q_values.get((next_state, next_action.item()), 0)
for next_action in self._action_space
]
if done:
next_state_value = 0
else:
# pyre-fixme[6]: For 1st argument expected
# `Iterable[Variable[SupportsRichComparisonT (bound to
# Union[SupportsDunderGT[typing.Any],
# SupportsDunderLT[typing.Any]])]]` but got `List[Number]`.
max_next_q_value = max(next_q_values) if next_q_values else 0
next_state_value = self.discount_factor * max_next_q_value
# pyre-fixme[58]: `+` is not supported for operand types `Number` and
# `float`.
# FIXME: not finding a generic assertion that would fix this.
# assert isinstance(old_q_value, Union[torch.Tensor, int, float])
# does not work. Pending discussion.
new_q_value = old_q_value + self.learning_rate * (
reward + next_state_value - old_q_value
)
self.q_values[(state, action.item())] = new_q_value
if self.debug:
self.print_debug_information(state, action, reward, next_state, done)
return {
"state": state,
"action": action,
"reward": reward,
"next_state": next_state,
"done": done,
}
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
raise Exception("tabular_q_learning doesnt need learn_batch")
def print_debug_information(
self,
state: SubjectiveState,
action: Action,
reward: Reward,
next_state: SubjectiveState,
done: bool,
) -> None:
print("state:", state)
print("action:", action)
print("reward:", reward)
print("next state:", next_state)
print("done:", done)
print("q-values:", self.q_values)
def __str__(self) -> str:
exploration_module = self._exploration_module
assert isinstance(exploration_module, EGreedyExploration)
items = [
"α=" + str(self.learning_rate),
"ε=" + str(exploration_module.epsilon),
"λ=" + str(self.discount_factor),
]
return "Q-Learning" + (
" (" + ", ".join(str(item) for item in items) + ")" if items else ""
)
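The learn method above applies the standard tabular Q-learning update Q(s, a) ← Q(s, a) + α · (r + γ · max_a' Q(s', a') − Q(s, a)), dropping the bootstrap term when the episode is done. The following is a minimal, framework-free sketch of that same update, shown only for illustration; it uses plain ints for states and actions and is not part of Pearl's API.

from typing import Dict, List, Tuple

def q_learning_update(
    q_values: Dict[Tuple[int, int], float],
    state: int,
    action: int,
    reward: float,
    next_state: int,
    next_actions: List[int],
    done: bool,
    learning_rate: float = 0.01,
    discount_factor: float = 0.9,
) -> None:
    # Current estimate, defaulting to 0 for unseen (state, action) pairs.
    old_q = q_values.get((state, action), 0.0)
    if done:
        # Terminal transition: no bootstrapping from the next state.
        target = reward
    else:
        # Bootstrap from the best available action in the next state.
        best_next = max(
            (q_values.get((next_state, a), 0.0) for a in next_actions),
            default=0.0,
        )
        target = reward + discount_factor * best_next
    q_values[(state, action)] = old_q + learning_rate * (target - old_q)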
Classes
class TabularQLearning (learning_rate: float = 0.01, discount_factor: float = 0.9, exploration_rate: float = 0.01, debug: bool = False)
A tabular Q-learning policy learner.
Initializes the tabular Q-learning policy learner.
Args
learning_rate : float, optional
    The learning rate. Defaults to 0.01.
discount_factor : float, optional
    The discount factor. Defaults to 0.9.
exploration_rate : float, optional
    The exploration rate. Defaults to 0.01.
debug : bool, optional
    Whether to print debug information to standard output. Defaults to False.
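A minimal usage sketch (not taken from Pearl's own documentation): it assumes DiscreteSpace can be constructed from a list of single-element tensors, and it stands in a hand-rolled one-transition buffer for a real Pearl replay buffer, matching the 9-field tuple that learn unpacks above.

import torch

from pearl.policy_learners.sequential_decision_making.tabular_q_learning import (
    TabularQLearning,
)
from pearl.utils.instantiations.spaces.discrete import DiscreteSpace

# Assumption: DiscreteSpace is built from a list of single-element tensors.
action_space = DiscreteSpace([torch.tensor([0]), torch.tensor([1])])

policy_learner = TabularQLearning(learning_rate=0.1, discount_factor=0.95)
policy_learner.reset(action_space)

state = torch.tensor([0])  # toy subjective state

# Greedy action for the current state (ties are broken at random).
action = policy_learner.act(state, action_space, exploit=True)

# Hypothetical stand-in for a replay buffer holding a single transition,
# matching the 9-field tuple that TabularQLearning.learn unpacks.
class OneTransitionBuffer:
    def __init__(self, transition):
        self._transition = transition

    def sample(self, batch_size):
        return [self._transition]

transition = (
    state,              # state
    action,             # action
    1.0,                # reward
    torch.tensor([1]),  # next_state
    None,               # curr_available_actions (unused here)
    None,               # next_available_actions (unused here)
    False,              # done
    None,               # max_number_actions (unused here)
    None,               # cost (unused here)
)
policy_learner.learn(OneTransitionBuffer(transition))
print(policy_learner)  # e.g. "Q-Learning (α=0.1, ε=0.01, λ=0.95)"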
Ancestors
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
def print_debug_information(self, state: torch.Tensor, action: torch.Tensor, reward: object, next_state: torch.Tensor, done: bool) ‑> None
Inherited members