Module pearl.policy_learners.sequential_decision_making.deep_q_learning
from typing import Any, Optional, Tuple
import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.deep_td_learning import (
    DeepTDLearning,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
class DeepQLearning(DeepTDLearning):
    """
    Deep Q Learning Policy Learner
    """
    def __init__(
        self,
        state_dim: int,
        learning_rate: float = 0.001,
        action_space: Optional[ActionSpace] = None,
        exploration_module: Optional[ExplorationModule] = None,
        soft_update_tau: float = 1.0,  # no soft update
        action_representation_module: Optional[ActionRepresentationModule] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            exploration_module=exploration_module
            if exploration_module is not None
            else EGreedyExploration(0.05),
            on_policy=False,
            state_dim=state_dim,
            action_space=action_space,
            learning_rate=learning_rate,
            soft_update_tau=soft_update_tau,
            action_representation_module=action_representation_module,
            **kwargs,
        )
    @torch.no_grad()
    def _get_next_state_values(
        self, batch: TransitionBatch, batch_size: int
    ) -> torch.Tensor:
        (
            next_state,
            next_available_actions,
            next_unavailable_actions_mask,
        ) = self._prepare_next_state_action_batch(batch)
        assert next_available_actions is not None
        next_state_action_values = self._Q_target.get_q_values(
            next_state, next_available_actions
        ).view(batch_size, -1)
        # (batch_size x action_space_size)
        # Assign -inf to the Q-values of unavailable actions so the max below
        # can never select them.
        next_state_action_values[next_unavailable_actions_mask] = -float("inf")
        # torch.max along dim 1 returns (values, indices); keep the values.
        return next_state_action_values.max(1)[0]  # (batch_size)
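    # Note: the max above supplies the bootstrap term max_a' Q_target(s', a')
    # of the one-step Q-learning target
    #     y = r + gamma * (1 - done) * max_a' Q_target(s', a'),
    # which is assembled in the parent class DeepTDLearning.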
    def _prepare_next_state_action_batch(
        self, batch: TransitionBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        next_state_batch = batch.next_state  # (batch_size x state_dim)
        assert next_state_batch is not None
        next_available_actions_batch = batch.next_available_actions
        # (batch_size x action_space_size x action_dim)
        next_unavailable_actions_mask_batch = batch.next_unavailable_actions_mask
        # (batch_size x action_space_size)
        assert isinstance(self._action_space, DiscreteActionSpace)
        number_of_actions = self._action_space.n
        next_state_batch_repeated = torch.repeat_interleave(
            next_state_batch.unsqueeze(1), number_of_actions, dim=1
        )  # (batch_size x action_space_size x state_dim)
        return (
            next_state_batch_repeated,
            next_available_actions_batch,
            next_unavailable_actions_mask_batch,
        )
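The masked maximization in _get_next_state_values is the core of the target computation. A minimal, self-contained sketch with toy tensors (values made up for illustration) shows how the -inf masking keeps unavailable actions from ever being selected by the max:

import torch

# Q-values for 2 next states x 3 actions, e.g. from a target network.
next_state_action_values = torch.tensor(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
)
# True where an action is unavailable in the corresponding next state.
next_unavailable_actions_mask = torch.tensor(
    [[False, False, True],   # action 2 unavailable in state 0
     [False, True, False]]   # action 1 unavailable in state 1
)
# Unavailable actions get -inf, so the row-wise max can never pick them.
next_state_action_values[next_unavailable_actions_mask] = -float("inf")
next_state_values = next_state_action_values.max(1)[0]
print(next_state_values)  # tensor([2., 6.])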
Classes
class DeepQLearning (state_dim: int, learning_rate: float = 0.001, action_space: Optional[ActionSpace] = None, exploration_module: Optional[ExplorationModule] = None, soft_update_tau: float = 1.0, action_representation_module: Optional[ActionRepresentationModule] = None, **kwargs: Any)
Deep Q-learning policy learner. Performs off-policy TD learning of a Q-value network, bootstrapping from the maximum target-network Q-value over the available next actions.
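A minimal usage sketch, assuming Pearl's agent, replay-buffer, and environment wrappers (PearlAgent, FIFOOffPolicyReplayBuffer, GymEnvironment) as in the project tutorials; exact module paths and names may differ across Pearl versions, and hidden_dims/training_rounds here are forwarded to DeepTDLearning via **kwargs:

from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment

env = GymEnvironment("CartPole-v1")

agent = PearlAgent(
    policy_learner=DeepQLearning(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        hidden_dims=[64, 64],   # forwarded to DeepTDLearning via **kwargs
        training_rounds=20,
    ),
    replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
)

# One episode of environment interaction with online learning.
observation, action_space = env.reset()
agent.reset(observation, action_space)
done = False
while not done:
    action = agent.act(exploit=False)
    action_result = env.step(action)
    agent.observe(action_result)
    agent.learn()
    done = action_result.done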
Ancestors
- DeepTDLearning
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC