Module pearl.policy_learners.sequential_decision_making.quantile_regression_deep_td_learning
Expand source code
import copy
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Type
import torch
import torch.nn.functional as F
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import update_target_network
from pearl.neural_networks.common.value_networks import QuantileQValueNetwork
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import DistributionalPolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.risk_sensitive_safety_modules import ( # noqa
RiskNeutralSafetyModule, # noqa
)
from pearl.utils.functional_utils.learning.loss_fn_utils import (
compute_elementwise_huber_loss,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import optim
# TODO: Only supports discrete action space problems for now; assumes a Gym action space.
class QuantileRegressionDeepTDLearning(DistributionalPolicyLearner):
"""
An Abstract Class for Quantile Regression based Deep Temporal Difference learning.
"""
def __init__(
self,
state_dim: int,
action_space: ActionSpace,
on_policy: bool,
exploration_module: ExplorationModule,
hidden_dims: Optional[List[int]] = None,
num_quantiles: int = 10,
learning_rate: float = 5 * 0.0001,
discount_factor: float = 0.99,
training_rounds: int = 100,
batch_size: int = 128,
target_update_freq: int = 10,
soft_update_tau: float = 0.05, # typical value for soft update
network_type: Type[
QuantileQValueNetwork
] = QuantileQValueNetwork, # C51 might use a different network type; add that later
network_instance: Optional[QuantileQValueNetwork] = None,
action_representation_module: Optional[ActionRepresentationModule] = None,
) -> None:
assert isinstance(action_space, DiscreteActionSpace)
super(QuantileRegressionDeepTDLearning, self).__init__(
training_rounds=training_rounds,
batch_size=batch_size,
exploration_module=exploration_module,
on_policy=on_policy,
is_action_continuous=False,
action_representation_module=action_representation_module,
)
if hidden_dims is None:
hidden_dims = []
self._action_space = action_space
self._discount_factor = discount_factor
self._target_update_freq = target_update_freq
self._soft_update_tau = soft_update_tau
self._num_quantiles = num_quantiles
def make_specified_network() -> QuantileQValueNetwork:
assert hidden_dims is not None
return network_type(
state_dim=state_dim,
action_dim=action_space.n, # pyre-ignore[16]
hidden_dims=hidden_dims,
num_quantiles=num_quantiles,
)
if network_instance is not None:
self._Q: QuantileQValueNetwork = network_instance
assert network_instance.state_dim == state_dim, (
"input state dimension doesn't match network "
"state dimension for QuantileQValueNetwork"
)
assert network_instance.action_dim == action_space.n, (
"input action dimension doesn't match network "
"action dimension for QuantileQValueNetwork"
)
else:
assert hidden_dims is not None
self._Q: QuantileQValueNetwork = make_specified_network()
self._Q_target: QuantileQValueNetwork = copy.deepcopy(self._Q)
self._optimizer = optim.AdamW(
self._Q.parameters(), lr=learning_rate, amsgrad=True
)
def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value
def reset(self, action_space: ActionSpace) -> None:
self._action_space = action_space
def act(
self,
subjective_state: SubjectiveState,
available_action_space: ActionSpace,
exploit: bool = False,
) -> Action:
assert isinstance(available_action_space, DiscreteActionSpace)
# Fix the available action space.
with torch.no_grad():
states_repeated = torch.repeat_interleave(
subjective_state.unsqueeze(0),
available_action_space.n,
dim=0,
) # (action_space_size x state_dim)
actions = F.one_hot(torch.arange(0, available_action_space.n)).to(
subjective_state.device
)
# (action_space_size, action_dim)
# instead of using the 'get_q_values' method of the QuantileQValueNetwork,
# we invoke a method from the risk sensitive safety module
q_values = self.safety_module.get_q_values_under_risk_metric(
states_repeated, actions, self._Q
)
exploit_action = torch.argmax(q_values).view((-1))
if exploit:
return exploit_action
return self._exploration_module.act(
subjective_state,
available_action_space,
exploit_action,
q_values,
)
# QR DQN, QR SAC and QR SARSA will implement this differently
@abstractmethod
def _get_next_state_quantiles(
self, batch: TransitionBatch, batch_size: int
) -> torch.Tensor:
pass
# learn quantiles of q value distribution using distribution temporal
# difference learning (specifically, quantile regression)
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"""
Assume N is the number of quantiles.
- This is the learning update for the quantile q value network which,
for each (state, action) pair, computes the quantile locations
(theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
- The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i (s,a),
where (theta_1(s,a), .. , theta_N(s,a)),
which represent the quantile locations, are outputs of the QuantileQValueNetwork.
- Loss function:
sum_{i=1}^N E_{j} [ rho_{tau^*_i}( T theta_j(s',a*) - theta_i(s,a) ) ] - Eq (1)
- tau^*_i is the i-th quantile midpoint ((tau_i + tau_{i-1})/2),
- T is the distributional Bellman operator,
- rho_tau(.) is the asymmetric quantile huber loss function,
- theta_i and theta_j are outputs of the QuantileQValueNetwork,
representing locations of quantiles,
- a* is the greedy action with respect to Q values (computed from the q value
distribution under some risk metric)
See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.
"""
batch_size = batch.state.shape[0]
"""
Step 1: a forward pass through the quantile network which gives quantile locations,
theta(s,a), for each (state, action) pair
"""
# a forward pass through the quantile network which gives quantile locations
# for each (state, action) pair
quantile_state_action_values = self._Q.get_q_value_distribution(
state_batch=batch.state, action_batch=batch.action
) # shape: (batch_size, num_quantiles)
"""
Step 2: compute Bellman target for each quantile location
- add a dimension to the reward and (1-done) vectors so they
can be broadcasted with the next state quantiles
"""
with torch.no_grad():
quantile_next_state_greedy_action_values = self._get_next_state_quantiles(
batch, batch_size
) * self._discount_factor * (1 - batch.done.float()).unsqueeze(
-1
) + batch.reward.unsqueeze(
-1
)
"""
Step 3: pairwise distributional quantile loss:
T theta_j(s',a*) - theta_i(s,a) for i,j in (1, .. , N)
- output shape: (batch_size, N, N)
"""
pairwise_quantile_loss = quantile_next_state_greedy_action_values.unsqueeze(
2
) - quantile_state_action_values.unsqueeze(1)
# elementwise huber loss smoothes the quantile loss, since it is non-smooth at 0
huber_loss = compute_elementwise_huber_loss(pairwise_quantile_loss)
with torch.no_grad():
asymmetric_weight = torch.abs(
self._Q.quantile_midpoints - (pairwise_quantile_loss < 0).float()
)
"""
# Step 4: compute asymmetric huber loss (also known as the quantile huber loss)
- output shape: (batch_size, N, N)
"""
quantile_huber_loss = asymmetric_weight * huber_loss
"""
Step 5: compute loss to optimize: given pairwise quantile huber loss,
- sum(dim=1) approximates the (sum_{i=1}^N [ .. ]) term in Equation (1),
- mean() takes average over the other quantile dimension (E_j [ .. ]) and over batch
"""
quantile_bellman_loss = quantile_huber_loss.sum(dim=1).mean()
# optimize model (parameters of quantile q network)
self._optimizer.zero_grad()
quantile_bellman_loss.backward()
self._optimizer.step()
# target network update
if (self._training_steps + 1) % self._target_update_freq == 0:
update_target_network(self._Q_target, self._Q, self._soft_update_tau)
return {
"loss": torch.abs(
quantile_state_action_values - quantile_next_state_greedy_action_values
)
.mean()
.item()
}
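The class above leaves _get_next_state_quantiles abstract. As a rough illustration of what a concrete QR-DQN-style subclass might do, the hypothetical sketch below (reusing the imports at the top of this module) evaluates the target network's quantile locations for every next-state/action pair, picks the greedy action from their risk-neutral mean, and returns that action's quantiles. The `next_state` field on TransitionBatch and the float one-hot action encoding are assumptions; this is not the library's actual implementation.

class QuantileRegressionDeepQLearningSketch(QuantileRegressionDeepTDLearning):
    """Hypothetical QR-DQN-style subclass; illustrative only, not the library's implementation."""

    def _get_next_state_quantiles(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        num_actions = self._action_space.n  # discrete action space, as asserted in __init__
        # Repeat each next state once per action: (batch_size * num_actions, state_dim).
        # Assumes TransitionBatch carries a `next_state` tensor of shape (batch_size, state_dim).
        next_states = torch.repeat_interleave(batch.next_state, num_actions, dim=0)
        # Float one-hot encoding of every action, tiled across the batch (an assumption about
        # the action representation expected by QuantileQValueNetwork).
        actions = F.one_hot(torch.arange(num_actions)).float().to(next_states.device)
        actions = actions.repeat(batch_size, 1)  # (batch_size * num_actions, num_actions)
        with torch.no_grad():
            # Target-network quantile locations, reshaped to (batch_size, num_actions, num_quantiles).
            quantiles = self._Q_target.get_q_value_distribution(
                state_batch=next_states, action_batch=actions
            ).view(batch_size, num_actions, -1)
        # Risk-neutral greedy action: argmax over the mean of each action's quantile locations.
        greedy_actions = quantiles.mean(dim=-1).argmax(dim=-1)  # (batch_size,)
        # Quantiles of the greedy action, shape (batch_size, num_quantiles); learn_batch then
        # applies the reward and discount to form the Bellman target.
        return quantiles[torch.arange(batch_size, device=quantiles.device), greedy_actions]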
Classes
class QuantileRegressionDeepTDLearning (state_dim: int, action_space: ActionSpace, on_policy: bool, exploration_module: ExplorationModule, hidden_dims: Optional[List[int]] = None, num_quantiles: int = 10, learning_rate: float = 0.0005, discount_factor: float = 0.99, training_rounds: int = 100, batch_size: int = 128, target_update_freq: int = 10, soft_update_tau: float = 0.05, network_type: Type[QuantileQValueNetwork] = pearl.neural_networks.common.value_networks.QuantileQValueNetwork, network_instance: Optional[QuantileQValueNetwork] = None, action_representation_module: Optional[ActionRepresentationModule] = None)
-
An Abstract Class for Quantile Regression based Deep Temporal Difference learning.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
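A minimal construction sketch, assuming the hypothetical concrete subclass sketched after the module source above. The EGreedyExploration import path, its epsilon argument, and the DiscreteActionSpace(actions=...) constructor are assumptions about the surrounding package, not guarantees.

import torch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
# Import path and constructor below are assumptions about the package layout:
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)

# A small discrete action space with four actions.
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(4)])

learner = QuantileRegressionDeepQLearningSketch(  # hypothetical concrete subclass, see above
    state_dim=8,
    action_space=action_space,
    on_policy=False,
    exploration_module=EGreedyExploration(epsilon=0.1),
    hidden_dims=[64, 64],
    num_quantiles=10,
)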
Ancestors
- DistributionalPolicyLearner
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
-
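Repeats the subjective state once per available action, pairs it with one-hot action encodings, and asks the risk-sensitive safety module for Q-values under the configured risk metric (rather than calling get_q_values on the network directly). With exploit=True it returns the greedy action; otherwise the exploration module decides. A toy sketch of the shape mechanics, assuming state_dim=8 and four actions:

import torch
import torch.nn.functional as F

state = torch.zeros(8)                                                   # subjective state, (state_dim,)
states_repeated = torch.repeat_interleave(state.unsqueeze(0), 4, dim=0)  # (4, state_dim)
actions = F.one_hot(torch.arange(0, 4))                                  # (4, 4) one-hot actions
# act() would now call safety_module.get_q_values_under_risk_metric(states_repeated, actions, self._Q)
# to get one Q-value per action, and take the argmax when exploit=True.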
def learn_batch(self, batch: TransitionBatch) ‑> Dict[str, Any]
-
Assume N is the number of quantiles.
- This is the learning update for the quantile q value network which, for each (state, action) pair, computes the quantile locations (theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
- The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i(s,a), where (theta_1(s,a), .. , theta_N(s,a)), which represent the quantile locations, are outputs of the QuantileQValueNetwork.
- Loss function: sum_{i=1}^N E_{j} [ rho_{tau^*_i}( T theta_j(s',a*) - theta_i(s,a) ) ] - Eq (1)
- tau^*_i is the i-th quantile midpoint ((tau_i + tau_{i-1})/2),
- T is the distributional Bellman operator,
- rho_tau(.) is the asymmetric quantile huber loss function,
- theta_i and theta_j are outputs of the QuantileQValueNetwork, representing locations of quantiles,
- a* is the greedy action with respect to Q values (computed from the q value distribution under some risk metric)
See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.
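A standalone toy sketch of the quantile Huber loss in Eq (1), with batch size 2 and N = 4 quantiles. The Huber threshold (kappa = 1) stands in for whatever compute_elementwise_huber_loss uses, which is an assumption.

import torch

N = 4
tau_hat = (torch.arange(N, dtype=torch.float32) + 0.5) / N   # quantile midpoints tau^*_i = (i - 0.5)/N
theta = torch.randn(2, N)                                     # theta_i(s, a) from the online network
target = torch.randn(2, N)                                    # T theta_j(s', a*): Bellman targets

# Pairwise differences u[b, i, j] = T theta_j - theta_i, shape (2, N, N).
u = target.unsqueeze(1) - theta.unsqueeze(2)
huber = torch.where(u.abs() <= 1.0, 0.5 * u ** 2, u.abs() - 0.5)   # elementwise Huber, kappa = 1
weight = torch.abs(tau_hat.view(N, 1) - (u < 0).float())           # asymmetric weight |tau^*_i - 1{u < 0}|
loss = (weight * huber).mean(dim=2).sum(dim=1).mean()              # E_j, then sum over i, then batch mean

The resulting scalar follows the structure of Eq (1): an expectation over target quantiles j, a sum over predicted quantiles i, and an average over the batch.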
def set_history_summarization_module(self, value: HistorySummarizationModule) ‑> None
-
Inherited members