Module pearl.policy_learners.sequential_decision_making.quantile_regression_deep_td_learning

Expand source code
import copy
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Type

import torch
import torch.nn.functional as F
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)

from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.history_summarization_modules.history_summarization_module import (
    HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import update_target_network
from pearl.neural_networks.common.value_networks import QuantileQValueNetwork
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.policy_learner import DistributionalPolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.risk_sensitive_safety_modules import (  # noqa
    RiskNeutralSafetyModule,  # noqa
)
from pearl.utils.functional_utils.learning.loss_fn_utils import (
    compute_elementwise_huber_loss,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import optim


# TODO: Only support discrete action space problems for now and assumes Gym action space.
class QuantileRegressionDeepTDLearning(DistributionalPolicyLearner):
    """
    An Abstract Class for Quantile Regression based Deep Temporal Difference learning.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        on_policy: bool,
        exploration_module: ExplorationModule,
        hidden_dims: Optional[List[int]] = None,
        num_quantiles: int = 10,
        learning_rate: float = 5 * 0.0001,
        discount_factor: float = 0.99,
        training_rounds: int = 100,
        batch_size: int = 128,
        target_update_freq: int = 10,
        soft_update_tau: float = 0.05,  # typical value for soft update
        network_type: Type[
            QuantileQValueNetwork
        ] = QuantileQValueNetwork,  # C51 might use a different network type; add that later
        network_instance: Optional[QuantileQValueNetwork] = None,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        assert isinstance(action_space, DiscreteActionSpace)
        super(QuantileRegressionDeepTDLearning, self).__init__(
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            on_policy=on_policy,
            is_action_continuous=False,
            action_representation_module=action_representation_module,
        )

        if hidden_dims is None:
            hidden_dims = []

        self._action_space = action_space
        self._discount_factor = discount_factor
        self._target_update_freq = target_update_freq
        self._soft_update_tau = soft_update_tau
        self._num_quantiles = num_quantiles

        def make_specified_network() -> QuantileQValueNetwork:
            assert hidden_dims is not None
            return network_type(
                state_dim=state_dim,
                action_dim=action_space.n,  # pyre-ignore[16]
                hidden_dims=hidden_dims,
                num_quantiles=num_quantiles,
            )

        if network_instance is not None:
            self._Q: QuantileQValueNetwork = network_instance
            assert network_instance.state_dim == state_dim, (
                "input state dimension doesn't match network "
                "state dimension for QuantileQValueNetwork"
            )
            assert network_instance.action_dim == action_space.n, (
                "input action dimension doesn't match network "
                "action dimension for QuantileQValueNetwork"
            )
        else:
            assert hidden_dims is not None
            self._Q: QuantileQValueNetwork = make_specified_network()

        self._Q_target: QuantileQValueNetwork = copy.deepcopy(self._Q)
        self._optimizer = optim.AdamW(
            self._Q.parameters(), lr=learning_rate, amsgrad=True
        )

    def set_history_summarization_module(
        self, value: HistorySummarizationModule
    ) -> None:
        self._optimizer.add_param_group({"params": value.parameters()})
        self._history_summarization_module = value

    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space

    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        exploit: bool = False,
    ) -> Action:
        assert isinstance(available_action_space, DiscreteActionSpace)
        # Fix the available action space.
        with torch.no_grad():
            states_repeated = torch.repeat_interleave(
                subjective_state.unsqueeze(0),
                available_action_space.n,
                dim=0,
            )  # (action_space_size x state_dim)

            actions = F.one_hot(torch.arange(0, available_action_space.n)).to(
                subjective_state.device
            )
            # (action_space_size, action_dim)

            # instead of using the 'get_q_values' method of the QuantileQValueNetwork,
            # we invoke a method from the risk sensitive safety module
            q_values = self.safety_module.get_q_values_under_risk_metric(
                states_repeated, actions, self._Q
            )
            exploit_action = torch.argmax(q_values).view((-1))

        if exploit:
            return exploit_action

        return self._exploration_module.act(
            subjective_state,
            available_action_space,
            exploit_action,
            q_values,
        )

    # QR DQN, QR SAC and QR SARSA will implement this differently
    @abstractmethod
    def _get_next_state_quantiles(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        pass

    # learn quantiles of q value distribution using distribution temporal
    # difference learning (specifically, quantile regression)
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Assume N is the number of quantiles.

        - This is the learning update for the quantile q value network which,
        for each (state, action) pair, computes the quantile locations
        (theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
        - The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i (s,a),
        where (theta_1(s,a), .. , theta_N(s,a)),
        which represent the quantile locations, are outouts of the QuantileQValueNetwork.
        - Loss function:
                sum_{i=1}^N E_{j} [ rho_{tau^*_i}( T theta_j(s',a*) - theta_i(s,a) ) ] - Eq (1)

            - tau^*_i is the i-th quantile midpoint ((tau_i + tau_{i-1})/2),
            - T is the distributional Bellman operator,
            - rho_tau(.) is the asymmetric quantile huber loss function,
            - theta_i and theta_j are outputs of the QuantileQValueNetwork,
              representing locations of quantiles,
            - a* is the greedy action with respect to Q values (computed from the q value
              distribution under some risk metric)

        See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.
        """

        batch_size = batch.state.shape[0]

        """
        Step 1: a forward pass through the quantile network which gives quantile locations,
        theta(s,a), for each (state, action) pair
        """
        # a forward pass through the quantile network which gives quantile locations
        # for each (state, action) pair
        quantile_state_action_values = self._Q.get_q_value_distribution(
            state_batch=batch.state, action_batch=batch.action
        )  # shape: (batch_size, num_quantiles)

        """
        Step 2: compute Bellman target for each quantile location
            - add a dimension to the reward and (1-done) vectors so they
              can be broadcasted with the next state quantiles
        """

        with torch.no_grad():
            quantile_next_state_greedy_action_values = self._get_next_state_quantiles(
                batch, batch_size
            ) * self._discount_factor * (1 - batch.done.float()).unsqueeze(
                -1
            ) + batch.reward.unsqueeze(
                -1
            )

        """
        Step 3: pairwise distributional quantile loss:
        T theta_j(s',a*) - theta_i(s,a) for i,j in (1, .. , N)
            - output shape: (batch_size, N, N)
        """
        pairwise_quantile_loss = quantile_next_state_greedy_action_values.unsqueeze(
            2
        ) - quantile_state_action_values.unsqueeze(1)

        # elementwise huber loss smoothes the quantile loss, since it is non-smooth at 0
        huber_loss = compute_elementwise_huber_loss(pairwise_quantile_loss)

        with torch.no_grad():
            asymmetric_weight = torch.abs(
                self._Q.quantile_midpoints - (pairwise_quantile_loss < 0).float()
            )

        """
        # Step 4: compute asymmetric huber loss (also known as the quantile huber loss)
            - output shape: (batch_size, N, N)
        """
        quantile_huber_loss = asymmetric_weight * huber_loss

        """
        Step 5: compute loss to optimize: given pairwise quantile huber loss,
            - sum(dim=1) approximates the (sum_{i=1}^N [ .. ]) term in Equation (1),
            - mean() takes average over the other quantile dimension (E_j [ .. ]) and over batch
        """
        quantile_bellman_loss = quantile_huber_loss.sum(dim=1).mean()

        # optimize model (parameters of quantile q network)
        self._optimizer.zero_grad()
        quantile_bellman_loss.backward()
        self._optimizer.step()

        # target network update
        if (self._training_steps + 1) % self._target_update_freq == 0:
            update_target_network(self._Q_target, self._Q, self._soft_update_tau)

        return {
            "loss": torch.abs(
                quantile_state_action_values - quantile_next_state_greedy_action_values
            )
            .mean()
            .item()
        }

Classes

class QuantileRegressionDeepTDLearning (state_dim: int, action_space: ActionSpace, on_policy: bool, exploration_module: ExplorationModule, hidden_dims: Optional[List[int]] = None, num_quantiles: int = 10, learning_rate: float = 0.0005, discount_factor: float = 0.99, training_rounds: int = 100, batch_size: int = 128, target_update_freq: int = 10, soft_update_tau: float = 0.05, network_type: Type[QuantileQValueNetwork] = pearl.neural_networks.common.value_networks.QuantileQValueNetwork, network_instance: Optional[QuantileQValueNetwork] = None, action_representation_module: Optional[ActionRepresentationModule] = None)

An Abstract Class for Quantile Regression based Deep Temporal Difference learning.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class QuantileRegressionDeepTDLearning(DistributionalPolicyLearner):
    """
    An Abstract Class for Quantile Regression based Deep Temporal Difference learning.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        on_policy: bool,
        exploration_module: ExplorationModule,
        hidden_dims: Optional[List[int]] = None,
        num_quantiles: int = 10,
        learning_rate: float = 5 * 0.0001,
        discount_factor: float = 0.99,
        training_rounds: int = 100,
        batch_size: int = 128,
        target_update_freq: int = 10,
        soft_update_tau: float = 0.05,  # typical value for soft update
        network_type: Type[
            QuantileQValueNetwork
        ] = QuantileQValueNetwork,  # C51 might use a different network type; add that later
        network_instance: Optional[QuantileQValueNetwork] = None,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        assert isinstance(action_space, DiscreteActionSpace)
        super(QuantileRegressionDeepTDLearning, self).__init__(
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            on_policy=on_policy,
            is_action_continuous=False,
            action_representation_module=action_representation_module,
        )

        if hidden_dims is None:
            hidden_dims = []

        self._action_space = action_space
        self._discount_factor = discount_factor
        self._target_update_freq = target_update_freq
        self._soft_update_tau = soft_update_tau
        self._num_quantiles = num_quantiles

        def make_specified_network() -> QuantileQValueNetwork:
            assert hidden_dims is not None
            return network_type(
                state_dim=state_dim,
                action_dim=action_space.n,  # pyre-ignore[16]
                hidden_dims=hidden_dims,
                num_quantiles=num_quantiles,
            )

        if network_instance is not None:
            self._Q: QuantileQValueNetwork = network_instance
            assert network_instance.state_dim == state_dim, (
                "input state dimension doesn't match network "
                "state dimension for QuantileQValueNetwork"
            )
            assert network_instance.action_dim == action_space.n, (
                "input action dimension doesn't match network "
                "action dimension for QuantileQValueNetwork"
            )
        else:
            assert hidden_dims is not None
            self._Q: QuantileQValueNetwork = make_specified_network()

        self._Q_target: QuantileQValueNetwork = copy.deepcopy(self._Q)
        self._optimizer = optim.AdamW(
            self._Q.parameters(), lr=learning_rate, amsgrad=True
        )

    def set_history_summarization_module(
        self, value: HistorySummarizationModule
    ) -> None:
        self._optimizer.add_param_group({"params": value.parameters()})
        self._history_summarization_module = value

    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space

    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        exploit: bool = False,
    ) -> Action:
        assert isinstance(available_action_space, DiscreteActionSpace)
        # Fix the available action space.
        with torch.no_grad():
            states_repeated = torch.repeat_interleave(
                subjective_state.unsqueeze(0),
                available_action_space.n,
                dim=0,
            )  # (action_space_size x state_dim)

            actions = F.one_hot(torch.arange(0, available_action_space.n)).to(
                subjective_state.device
            )
            # (action_space_size, action_dim)

            # instead of using the 'get_q_values' method of the QuantileQValueNetwork,
            # we invoke a method from the risk sensitive safety module
            q_values = self.safety_module.get_q_values_under_risk_metric(
                states_repeated, actions, self._Q
            )
            exploit_action = torch.argmax(q_values).view((-1))

        if exploit:
            return exploit_action

        return self._exploration_module.act(
            subjective_state,
            available_action_space,
            exploit_action,
            q_values,
        )

    # QR DQN, QR SAC and QR SARSA will implement this differently
    @abstractmethod
    def _get_next_state_quantiles(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        pass

    # learn quantiles of q value distribution using distribution temporal
    # difference learning (specifically, quantile regression)
    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Assume N is the number of quantiles.

        - This is the learning update for the quantile q value network which,
        for each (state, action) pair, computes the quantile locations
        (theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
        - The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i (s,a),
        where (theta_1(s,a), .. , theta_N(s,a)),
        which represent the quantile locations, are outouts of the QuantileQValueNetwork.
        - Loss function:
                sum_{i=1}^N E_{j} [ rho_{tau^*_i}( T theta_j(s',a*) - theta_i(s,a) ) ] - Eq (1)

            - tau^*_i is the i-th quantile midpoint ((tau_i + tau_{i-1})/2),
            - T is the distributional Bellman operator,
            - rho_tau(.) is the asymmetric quantile huber loss function,
            - theta_i and theta_j are outputs of the QuantileQValueNetwork,
              representing locations of quantiles,
            - a* is the greedy action with respect to Q values (computed from the q value
              distribution under some risk metric)

        See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.
        """

        batch_size = batch.state.shape[0]

        """
        Step 1: a forward pass through the quantile network which gives quantile locations,
        theta(s,a), for each (state, action) pair
        """
        # a forward pass through the quantile network which gives quantile locations
        # for each (state, action) pair
        quantile_state_action_values = self._Q.get_q_value_distribution(
            state_batch=batch.state, action_batch=batch.action
        )  # shape: (batch_size, num_quantiles)

        """
        Step 2: compute Bellman target for each quantile location
            - add a dimension to the reward and (1-done) vectors so they
              can be broadcasted with the next state quantiles
        """

        with torch.no_grad():
            quantile_next_state_greedy_action_values = self._get_next_state_quantiles(
                batch, batch_size
            ) * self._discount_factor * (1 - batch.done.float()).unsqueeze(
                -1
            ) + batch.reward.unsqueeze(
                -1
            )

        """
        Step 3: pairwise distributional quantile loss:
        T theta_j(s',a*) - theta_i(s,a) for i,j in (1, .. , N)
            - output shape: (batch_size, N, N)
        """
        pairwise_quantile_loss = quantile_next_state_greedy_action_values.unsqueeze(
            2
        ) - quantile_state_action_values.unsqueeze(1)

        # elementwise huber loss smoothes the quantile loss, since it is non-smooth at 0
        huber_loss = compute_elementwise_huber_loss(pairwise_quantile_loss)

        with torch.no_grad():
            asymmetric_weight = torch.abs(
                self._Q.quantile_midpoints - (pairwise_quantile_loss < 0).float()
            )

        """
        # Step 4: compute asymmetric huber loss (also known as the quantile huber loss)
            - output shape: (batch_size, N, N)
        """
        quantile_huber_loss = asymmetric_weight * huber_loss

        """
        Step 5: compute loss to optimize: given pairwise quantile huber loss,
            - sum(dim=1) approximates the (sum_{i=1}^N [ .. ]) term in Equation (1),
            - mean() takes average over the other quantile dimension (E_j [ .. ]) and over batch
        """
        quantile_bellman_loss = quantile_huber_loss.sum(dim=1).mean()

        # optimize model (parameters of quantile q network)
        self._optimizer.zero_grad()
        quantile_bellman_loss.backward()
        self._optimizer.step()

        # target network update
        if (self._training_steps + 1) % self._target_update_freq == 0:
            update_target_network(self._Q_target, self._Q, self._soft_update_tau)

        return {
            "loss": torch.abs(
                quantile_state_action_values - quantile_next_state_greedy_action_values
            )
            .mean()
            .item()
        }

Ancestors

Subclasses

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
Expand source code
def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: ActionSpace,
    exploit: bool = False,
) -> Action:
    assert isinstance(available_action_space, DiscreteActionSpace)
    # Fix the available action space.
    with torch.no_grad():
        states_repeated = torch.repeat_interleave(
            subjective_state.unsqueeze(0),
            available_action_space.n,
            dim=0,
        )  # (action_space_size x state_dim)

        actions = F.one_hot(torch.arange(0, available_action_space.n)).to(
            subjective_state.device
        )
        # (action_space_size, action_dim)

        # instead of using the 'get_q_values' method of the QuantileQValueNetwork,
        # we invoke a method from the risk sensitive safety module
        q_values = self.safety_module.get_q_values_under_risk_metric(
            states_repeated, actions, self._Q
        )
        exploit_action = torch.argmax(q_values).view((-1))

    if exploit:
        return exploit_action

    return self._exploration_module.act(
        subjective_state,
        available_action_space,
        exploit_action,
        q_values,
    )
def learn_batch(self, batch: TransitionBatch) ‑> Dict[str, Any]

Assume N is the number of quantiles.

  • This is the learning update for the quantile q value network which, for each (state, action) pair, computes the quantile locations (theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
  • The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i (s,a), where (theta_1(s,a), .. , theta_N(s,a)), which represent the quantile locations, are outouts of the QuantileQValueNetwork.
  • Loss function: sum_{i=1}^N E_{j} [ rho_{tau^_i}( T theta_j(s',a) - theta_i(s,a) ) ] - Eq (1)

    • tau^*i is the i-th quantile midpoint ((tau_i + tau)/2),
    • T is the distributional Bellman operator,
    • rho_tau(.) is the asymmetric quantile huber loss function,
    • theta_i and theta_j are outputs of the QuantileQValueNetwork, representing locations of quantiles,
    • a* is the greedy action with respect to Q values (computed from the q value distribution under some risk metric)

See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.

Expand source code
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
    """
    Assume N is the number of quantiles.

    - This is the learning update for the quantile q value network which,
    for each (state, action) pair, computes the quantile locations
    (theta_1(s,a), .. , theta_N(s,a)). The quantiles are fixed to be 1/N.
    - The return distribution is represented as: Z(s, a) = (1/N) * sum_{i=1}^N theta_i (s,a),
    where (theta_1(s,a), .. , theta_N(s,a)),
    which represent the quantile locations, are outouts of the QuantileQValueNetwork.
    - Loss function:
            sum_{i=1}^N E_{j} [ rho_{tau^*_i}( T theta_j(s',a*) - theta_i(s,a) ) ] - Eq (1)

        - tau^*_i is the i-th quantile midpoint ((tau_i + tau_{i-1})/2),
        - T is the distributional Bellman operator,
        - rho_tau(.) is the asymmetric quantile huber loss function,
        - theta_i and theta_j are outputs of the QuantileQValueNetwork,
          representing locations of quantiles,
        - a* is the greedy action with respect to Q values (computed from the q value
          distribution under some risk metric)

    See the parameterization in QR DQN paper: https://arxiv.org/pdf/1710.10044.pdf for details.
    """

    batch_size = batch.state.shape[0]

    """
    Step 1: a forward pass through the quantile network which gives quantile locations,
    theta(s,a), for each (state, action) pair
    """
    # a forward pass through the quantile network which gives quantile locations
    # for each (state, action) pair
    quantile_state_action_values = self._Q.get_q_value_distribution(
        state_batch=batch.state, action_batch=batch.action
    )  # shape: (batch_size, num_quantiles)

    """
    Step 2: compute Bellman target for each quantile location
        - add a dimension to the reward and (1-done) vectors so they
          can be broadcasted with the next state quantiles
    """

    with torch.no_grad():
        quantile_next_state_greedy_action_values = self._get_next_state_quantiles(
            batch, batch_size
        ) * self._discount_factor * (1 - batch.done.float()).unsqueeze(
            -1
        ) + batch.reward.unsqueeze(
            -1
        )

    """
    Step 3: pairwise distributional quantile loss:
    T theta_j(s',a*) - theta_i(s,a) for i,j in (1, .. , N)
        - output shape: (batch_size, N, N)
    """
    pairwise_quantile_loss = quantile_next_state_greedy_action_values.unsqueeze(
        2
    ) - quantile_state_action_values.unsqueeze(1)

    # elementwise huber loss smoothes the quantile loss, since it is non-smooth at 0
    huber_loss = compute_elementwise_huber_loss(pairwise_quantile_loss)

    with torch.no_grad():
        asymmetric_weight = torch.abs(
            self._Q.quantile_midpoints - (pairwise_quantile_loss < 0).float()
        )

    """
    # Step 4: compute asymmetric huber loss (also known as the quantile huber loss)
        - output shape: (batch_size, N, N)
    """
    quantile_huber_loss = asymmetric_weight * huber_loss

    """
    Step 5: compute loss to optimize: given pairwise quantile huber loss,
        - sum(dim=1) approximates the (sum_{i=1}^N [ .. ]) term in Equation (1),
        - mean() takes average over the other quantile dimension (E_j [ .. ]) and over batch
    """
    quantile_bellman_loss = quantile_huber_loss.sum(dim=1).mean()

    # optimize model (parameters of quantile q network)
    self._optimizer.zero_grad()
    quantile_bellman_loss.backward()
    self._optimizer.step()

    # target network update
    if (self._training_steps + 1) % self._target_update_freq == 0:
        update_target_network(self._Q_target, self._Q, self._soft_update_tau)

    return {
        "loss": torch.abs(
            quantile_state_action_values - quantile_next_state_greedy_action_values
        )
        .mean()
        .item()
    }
def set_history_summarization_module(self, value: HistorySummarizationModule) ‑> None
Expand source code
def set_history_summarization_module(
    self, value: HistorySummarizationModule
) -> None:
    self._optimizer.add_param_group({"params": value.parameters()})
    self._history_summarization_module = value

Inherited members