Module pearl.policy_learners.sequential_decision_making.quantile_regression_deep_q_learning
Expand source code
from typing import List, Optional

import torch

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.value_networks import QuantileQValueNetwork
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.quantile_regression_deep_td_learning import (
    QuantileRegressionDeepTDLearning,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

class QuantileRegressionDeepQLearning(QuantileRegressionDeepTDLearning):
    """
    Quantile Regression based Deep Q Learning Policy Learner.

    Notes:
        - Support for offline learning, by adding a conservative loss to the
          quantile-regression-based distributional temporal difference loss,
          has not been added (the literature does not seem to cover it).
        - To do: add support for providing a network instance as input.
    """
    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        hidden_dims: Optional[List[int]] = None,
        num_quantiles: int = 10,
        exploration_module: Optional[ExplorationModule] = None,
        on_policy: bool = False,
        learning_rate: float = 5 * 0.0001,
        discount_factor: float = 0.99,
        training_rounds: int = 100,
        batch_size: int = 128,
        target_update_freq: int = 10,
        soft_update_tau: float = 0.05,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        assert isinstance(action_space, DiscreteActionSpace)
        super(QuantileRegressionDeepQLearning, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            on_policy=on_policy,
            exploration_module=exploration_module
            if exploration_module is not None
            else EGreedyExploration(0.10),
            hidden_dims=hidden_dims,
            num_quantiles=num_quantiles,
            learning_rate=learning_rate,
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            target_update_freq=target_update_freq,
            soft_update_tau=soft_update_tau,
            network_type=QuantileQValueNetwork,  # enforced to be of type QuantileQValueNetwork
            action_representation_module=action_representation_module,
        )
    # QR-DQN is implemented on top of the QuantileRegressionDeepTDLearning class.
    @torch.no_grad()
    def _get_next_state_quantiles(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        next_state_batch = batch.next_state  # shape: (batch_size x state_dim)
        next_available_actions_batch = (
            batch.next_available_actions
        )  # shape: (batch_size x action_space_size x action_dim)
        next_unavailable_actions_mask_batch = (
            batch.next_unavailable_actions_mask
        )  # shape: (batch_size x action_space_size)
        assert next_state_batch is not None
        assert isinstance(self._action_space, DiscreteActionSpace)
        next_state_batch_repeated = torch.repeat_interleave(
            next_state_batch.unsqueeze(1),
            self._action_space.n,  # pyre-ignore[16]
            dim=1,
        )  # shape: (batch_size x action_space_size x state_dim)

        """
        Step 1: get quantiles for all possible actions in the batch
        - output shape: (batch_size x action_space_size x num_quantiles)
        """
        assert next_available_actions_batch is not None
        next_state_action_quantiles = self._Q_target.get_q_value_distribution(
            next_state_batch_repeated, next_available_actions_batch
        )

        # Get Q-values from the Q-value distribution under a risk metric.
        # Instead of using the 'get_q_values' method of QuantileQValueNetwork,
        # we invoke a method from the risk-sensitive safety module.
        next_state_action_values = self.safety_module.get_q_values_under_risk_metric(
            next_state_batch_repeated, next_available_actions_batch, self._Q_target
        ).view(
            batch_size, -1
        )  # shape: (batch_size x action_space_size)

        # Make sure that unavailable actions' Q-values are assigned -inf.
        next_state_action_values[next_unavailable_actions_mask_batch] = -float("inf")

        """
        Step 2: choose the greedy action for each state
        """
        greedy_action_idx = torch.argmax(next_state_action_values, dim=-1).unsqueeze(-1)

        """
        Step 3: get quantiles corresponding to the greedy action index using torch.gather
        - since the shape of next_state_action_quantiles is
          (batch_size x action_space_size x num_quantiles),
          and the shape of greedy_action_idx is (batch_size x 1),
          we need to expand greedy_action_idx along the last dimension for broadcasting
        """
        quantiles_greedy_action = torch.gather(
            input=next_state_action_quantiles,
            dim=1,
            index=greedy_action_idx.unsqueeze(-1).expand(
                -1, -1, next_state_action_quantiles.shape[-1]
            ),  # expands shape to (batch_size x 1 x num_quantiles)
        )
        return quantiles_greedy_action.view(batch_size, -1)  # shape: (batch_size x num_quantiles)
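To make Steps 2 and 3 concrete, below is a minimal, self-contained sketch using dummy tensors and made-up shapes (not Pearl objects): it masks unavailable actions, picks the greedy action per state, and uses torch.gather to extract that action's quantile vector, mirroring the logic of _get_next_state_quantiles.

import torch

# Illustrative shapes only.
batch_size, action_space_size, num_quantiles = 2, 3, 5

# Dummy per-action quantiles, standing in for the output of Step 1:
# shape (batch_size x action_space_size x num_quantiles).
next_state_action_quantiles = torch.randn(batch_size, action_space_size, num_quantiles)

# Dummy Q-values under a risk metric and a dummy unavailable-actions mask.
next_state_action_values = next_state_action_quantiles.mean(dim=-1)  # (batch_size x action_space_size)
next_unavailable_actions_mask = torch.tensor([[False, True, False], [False, False, False]])

# Step 2: mask out unavailable actions and choose the greedy action for each state.
next_state_action_values[next_unavailable_actions_mask] = -float("inf")
greedy_action_idx = torch.argmax(next_state_action_values, dim=-1).unsqueeze(-1)  # (batch_size x 1)

# Step 3: expand the index to (batch_size x 1 x num_quantiles) and gather the greedy
# action's quantile vector along the action dimension.
quantiles_greedy_action = torch.gather(
    input=next_state_action_quantiles,
    dim=1,
    index=greedy_action_idx.unsqueeze(-1).expand(-1, -1, num_quantiles),
)
assert quantiles_greedy_action.view(batch_size, -1).shape == (batch_size, num_quantiles)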
Classes
class QuantileRegressionDeepQLearning (state_dim: int, action_space: ActionSpace, hidden_dims: Optional[List[int]] = None, num_quantiles: int = 10, exploration_module: Optional[ExplorationModule] = None, on_policy: bool = False, learning_rate: float = 0.0005, discount_factor: float = 0.99, training_rounds: int = 100, batch_size: int = 128, target_update_freq: int = 10, soft_update_tau: float = 0.05, action_representation_module: Optional[ActionRepresentationModule] = None)
Quantile Regression based Deep Q Learning Policy Learner.
Notes
- Support for offline learning, by adding a conservative loss to the quantile-regression-based distributional temporal difference loss, has not been added (the literature does not seem to cover it).
- To do: add support for providing a network instance as input.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
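A minimal instantiation sketch. The dimensions and hyperparameters are hypothetical, and the DiscreteActionSpace constructor call assumes it accepts a list of action tensors (check Pearl's instantiation utilities for the exact signature):

import torch

from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)
from pearl.policy_learners.sequential_decision_making.quantile_regression_deep_q_learning import (
    QuantileRegressionDeepQLearning,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# A toy discrete action space with 4 one-dimensional actions (hypothetical setup).
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(4)])

policy_learner = QuantileRegressionDeepQLearning(
    state_dim=8,  # hypothetical observation dimension
    action_space=action_space,
    hidden_dims=[64, 64],
    num_quantiles=32,
    exploration_module=EGreedyExploration(0.05),  # overrides the default EGreedyExploration(0.10)
    training_rounds=20,
)

Note that _get_next_state_quantiles calls self.safety_module.get_q_values_under_risk_metric, so in practice the learner is used with a risk-sensitive safety module attached (for example, via a Pearl agent).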
Ancestors
- QuantileRegressionDeepTDLearning
- DistributionalPolicyLearner
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Inherited members