Module pearl.policy_learners.sequential_decision_making.double_dqn

Expand source code
import torch
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class DoubleDQN(DeepQLearning):
    """
    Double DQN Policy Learner
    Compare to DQN, it gets a' from Q network and Q(s', a') from target network
    while DQN, get both a' and Q(s', a') from target network

    https://arxiv.org/pdf/1509.06461.pdf
    """

    @torch.no_grad()
    def _get_next_state_values(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        next_state_batch = batch.next_state  # (batch_size x state_dim)
        assert next_state_batch is not None

        next_available_actions_batch = (
            batch.next_available_actions
        )  # (batch_size x action_space_size x action_dim)
        assert next_available_actions_batch is not None

        next_unavailable_actions_mask_batch = (
            batch.next_unavailable_actions_mask
        )  # (batch_size x action_space_size)

        assert isinstance(self._action_space, DiscreteActionSpace)
        number_of_actions = self._action_space.n

        next_state_batch_repeated = torch.repeat_interleave(
            next_state_batch.unsqueeze(1),
            number_of_actions,
            dim=1,
        )  # (batch_size x action_space_size x state_dim)

        next_state_action_values = self._Q.get_q_values(
            next_state_batch_repeated, next_available_actions_batch
        ).view(
            (batch_size, -1)
        )  # (batch_size x action_space_size)
        # Assign -inf to the Q values of unavailable actions so they are never selected
        next_state_action_values[next_unavailable_actions_mask_batch] = -float("inf")

        # Tensor.max(dim=1) returns (values, indices); keep the argmax indices
        next_action_indices = next_state_action_values.max(1)[1]  # (batch_size)
        next_action_batch = next_available_actions_batch[
            torch.arange(next_available_actions_batch.size(0)),
            next_action_indices.squeeze(),
        ]
        # Evaluate the greedily selected actions with the target network (Double DQN)
        return self._Q_target.get_q_values(next_state_batch, next_action_batch)
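
Outside of Pearl's classes, the select-then-evaluate split in _get_next_state_values can be shown with toy tensors. The linear q_values helper, the random weights, and all shapes below are illustrative assumptions, not Pearl's API; only the structure mirrors the method above.

import torch

batch_size, state_dim, num_actions, action_dim = 4, 3, 5, 2

torch.manual_seed(0)
# Stand-ins for the online and target Q-networks: linear scorers over the
# concatenated (state, action) vector. Purely illustrative, not Pearl's API.
w_online = torch.randn(state_dim + action_dim)
w_target = torch.randn(state_dim + action_dim)


def q_values(w: torch.Tensor, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    # states: (..., state_dim), actions: (..., action_dim) -> (...)
    return torch.cat([states, actions], dim=-1) @ w


next_states = torch.randn(batch_size, state_dim)
# One action representation per (sample, action index).
next_available_actions = torch.randn(batch_size, num_actions, action_dim)
# Mark the last action of every sample as unavailable, for illustration.
unavailable_mask = torch.zeros(batch_size, num_actions, dtype=torch.bool)
unavailable_mask[:, -1] = True

# 1) Score every candidate action with the ONLINE network and mask out
#    unavailable actions so they can never be the argmax.
repeated_states = next_states.unsqueeze(1).expand(-1, num_actions, -1)
online_q = q_values(w_online, repeated_states, next_available_actions)
online_q[unavailable_mask] = -float("inf")

# 2) Greedy action selection uses the ONLINE network's scores.
greedy_idx = online_q.argmax(dim=1)  # (batch_size,)
greedy_actions = next_available_actions[torch.arange(batch_size), greedy_idx]

# 3) The selected actions are evaluated with the TARGET network (Double DQN).
next_state_values = q_values(w_target, next_states, greedy_actions)
print(next_state_values.shape)  # torch.Size([4])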

Classes

class DoubleDQN (state_dim: int, learning_rate: float = 0.001, action_space: Optional[ActionSpace] = None, exploration_module: Optional[ExplorationModule] = None, soft_update_tau: float = 1.0, action_representation_module: Optional[ActionRepresentationModule] = None, **kwargs: Any)

Double DQN policy learner. Compared to DQN, it selects a' with the online Q network and evaluates Q(s', a') with the target network, whereas DQN obtains both a' and Q(s', a') from the target network.

https://arxiv.org/pdf/1509.06461.pdf
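
In terms of the bootstrap target, the two methods differ only in how the next action is chosen (r, γ, and s' denote reward, discount factor, and next state; notation follows the paper above):

DQN:         y = r + γ · max_a Q_target(s', a)
Double DQN:  y = r + γ · Q_target(s', argmax_a Q(s', a))

Here Q is the online network and Q_target the target network.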

Initializes internal Module state, shared by both nn.Module and ScriptModule.
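
A minimal, hypothetical instantiation sketch. Only the parameters listed in the signature above are taken from this page; the DiscreteActionSpace constructor arguments and the hidden_dims keyword (assumed to be forwarded through **kwargs to the parent DeepQLearning) are assumptions and may need to be adapted.

import torch
from pearl.policy_learners.sequential_decision_making.double_dqn import DoubleDQN
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Assumption: DiscreteActionSpace wraps a list of per-action tensors
# (one-hot vectors here); check its definition for the exact constructor.
action_space = DiscreteActionSpace(actions=list(torch.eye(4)))

policy_learner = DoubleDQN(
    state_dim=8,
    action_space=action_space,
    learning_rate=1e-3,
    soft_update_tau=1.0,   # 1.0 corresponds to hard target-network updates
    hidden_dims=[64, 64],  # assumption: consumed by DeepQLearning via **kwargs
)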


Ancestors

DeepQLearning

Inherited members