Module pearl.neural_networks.common.value_networks

This module defines several types of value neural networks.

Expand source code
"""
This module defines several types of value neural networks.
"""


from abc import ABC
from typing import Any, List, Optional

import torch
import torch.nn as nn
from pearl.neural_networks.common.epistemic_neural_networks import Ensemble

from pearl.neural_networks.sequential_decision_making.q_value_network import (
    DistributionalQValueNetwork,
    QValueNetwork,
)
from pearl.utils.functional_utils.learning.extend_state_feature import (
    extend_state_feature_by_available_action_space,
)
from torch import Tensor

from .utils import conv_block, mlp_block


class ValueNetwork(nn.Module, ABC):
    """
    An interface for value neural networks.
    It does not add any required methods to those already present in
    its super classes.
    Its purpose instead is just to serve as an umbrella type for all value networks.
    """


class VanillaValueNetwork(ValueNetwork):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]],
        output_dim: int = 1,
        **kwargs: Any,
    ) -> None:
        super(VanillaValueNetwork, self).__init__()
        self._model: nn.Module = mlp_block(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    # default initialization in linear and conv layers of an nn.Sequential model is Kaiming
    def xavier_init(self) -> None:
        for layer in self._model:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)


class VanillaCNN(ValueNetwork):
    """
    Vanilla CNN with a convolutional block followed by an mlp block.
    Args:
        input_width: width of the input
        input_height: height of the input
        input_channels_count: number of input channels
        kernel_sizes: list of kernel sizes for the convolutional layers
        output_channels_list: list of number of output channels for each convolutional layer
        strides: list of strides for each layer
        paddings: list of paddings for each layer
        hidden_dims_fully_connected: a list of dimensions of the hidden layers in the mlp
        use_batch_norm_conv: whether to use batch_norm in the convolutional layers
        use_batch_norm_fully_connected: whether to use batch_norm in the fully connected layers
        output_dim: dimension of the output layer
    Returns:
        An nn.Sequential module consisting of a convolutional block followed by an mlp.
    """

    def __init__(
        self,
        input_width: int,
        input_height: int,
        input_channels_count: int,
        kernel_sizes: List[int],
        output_channels_list: List[int],
        strides: List[int],
        paddings: List[int],
        hidden_dims_fully_connected: Optional[
            List[int]
        ] = None,  # hidden dims for fully connected layers
        use_batch_norm_conv: bool = False,
        use_batch_norm_fully_connected: bool = False,
        output_dim: int = 1,  # dimension of the final output
    ) -> None:

        assert (
            len(kernel_sizes)
            == len(output_channels_list)
            == len(strides)
            == len(paddings)
        )
        super(VanillaCNN, self).__init__()

        self._input_channels = input_channels_count
        self._input_height = input_height
        self._input_width = input_width
        self._output_channels = output_channels_list
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._paddings = paddings
        if hidden_dims_fully_connected is None:
            self._hidden_dims_fully_connected: List[int] = []
        else:
            self._hidden_dims_fully_connected: List[int] = hidden_dims_fully_connected

        self._use_batch_norm_conv = use_batch_norm_conv
        self._use_batch_norm_fully_connected = use_batch_norm_fully_connected
        self._output_dim = output_dim

        self._model_cnn: nn.Module = conv_block(
            input_channels_count=self._input_channels,
            output_channels_list=self._output_channels,
            kernel_sizes=self._kernel_sizes,
            strides=self._strides,
            paddings=self._paddings,
            use_batch_norm=self._use_batch_norm_conv,
        )

        self._mlp_input_dims: int = self.compute_output_dim_model_cnn()
        self._model_fc: nn.Module = mlp_block(
            input_dim=self._mlp_input_dims,
            hidden_dims=self._hidden_dims_fully_connected,
            output_dim=self._output_dim,
            use_batch_norm=self._use_batch_norm_fully_connected,
        )

    def compute_output_dim_model_cnn(self) -> int:
        dummy_input = torch.zeros(
            1, self._input_channels, self._input_width, self._input_height
        )
        dummy_output_flattened = torch.flatten(
            self._model_cnn(dummy_input), start_dim=1, end_dim=-1
        )
        return dummy_output_flattened.shape[1]

    def forward(self, x: Tensor) -> Tensor:
        out_cnn = self._model_cnn(x)
        out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
        out_fc = self._model_fc(out_flattened)
        return out_fc


class CNNQValueNetwork(VanillaCNN):
    """
    A CNN version of state-action value (Q-value) network.
    """

    def __init__(
        self,
        input_width: int,
        input_height: int,
        input_channels_count: int,
        kernel_sizes: List[int],
        output_channels_list: List[int],
        strides: List[int],
        paddings: List[int],
        action_dim: int,
        hidden_dims_fully_connected: Optional[List[int]] = None,
        output_dim: int = 1,
        use_batch_norm_conv: bool = False,
        use_batch_norm_fully_connected: bool = False,
    ) -> None:
        super(CNNQValueNetwork, self).__init__(
            input_width=input_width,
            input_height=input_height,
            input_channels_count=input_channels_count,
            kernel_sizes=kernel_sizes,
            output_channels_list=output_channels_list,
            strides=strides,
            paddings=paddings,
            hidden_dims_fully_connected=hidden_dims_fully_connected,
            use_batch_norm_conv=use_batch_norm_conv,
            use_batch_norm_fully_connected=use_batch_norm_fully_connected,
            output_dim=output_dim,
        )
        # we concatenate actions to state representations in the mlp block of the Q-value network
        self._mlp_input_dims: int = self.compute_output_dim_model_cnn() + action_dim
        self._model_fc: nn.Module = mlp_block(
            input_dim=self._mlp_input_dims,
            hidden_dims=self._hidden_dims_fully_connected,
            output_dim=self._output_dim,
            use_batch_norm=self._use_batch_norm_fully_connected,
        )
        self._action_dim = action_dim

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = state_batch.shape[0]
        state_representation_batch = self._model_cnn(state_batch)
        state_embedding_batch = torch.flatten(
            state_representation_batch, start_dim=1, end_dim=-1
        ).view(batch_size, -1)

        # concatenate actions to state representations and do a forward pass through the mlp_block
        x = torch.cat([state_embedding_batch, action_batch], dim=-1)
        q_values_batch = self._model_fc(x)
        return q_values_batch.view(-1)

    @property
    def action_dim(self) -> int:
        return self._action_dim


class VanillaQValueNetwork(QValueNetwork):
    """
    A vanilla version of state-action value (Q-value) network.
    It leverages the vanilla implementation of value networks by
    using the state-action pair as the input for the value network.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        use_layer_norm: bool = False,
    ) -> None:
        super(VanillaQValueNetwork, self).__init__()
        self._state_dim: int = state_dim
        self._action_dim: int = action_dim
        self._model: nn.Module = mlp_block(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            use_layer_norm=use_layer_norm,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x).view(-1)

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim


class QuantileQValueNetwork(DistributionalQValueNetwork):
    """
    A quantile version of state-action value (Q-value) network.
    For each (state, action) input pair,
    it returns theta(s,a), the locations of the quantiles that parameterize the Q-value distribution.

    See the parameterization in the QR-DQN paper (https://arxiv.org/pdf/1710.10044.pdf) for more details.

    Assume N is the number of quantiles.
    1) For this parameterization, the quantile fractions are fixed and uniformly spaced (1/N each),
       while the quantile locations, theta(s,a), are learned.
    2) The return distribution is represented as a uniform mixture of Dirac deltas
       at the quantile locations: Z(s, a) = (1/N) * sum_{i=1}^N delta(theta_i(s, a)),
       where theta_1(s,a), ..., theta_N(s,a), the quantile locations,
       are outputs of the QuantileQValueNetwork.

    Args:
        num_quantiles: the number of quantiles N, used to approximate the value distribution.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        num_quantiles: int,
        use_layer_norm: bool = False,
    ) -> None:
        super(QuantileQValueNetwork, self).__init__()

        self._model: nn.Module = mlp_block(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=num_quantiles,
            use_layer_norm=use_layer_norm,
        )

        self._state_dim: int = state_dim
        self._action_dim: int = action_dim
        self._num_quantiles: int = num_quantiles
        self.register_buffer(
            "_quantiles", torch.arange(0, self._num_quantiles + 1) / self._num_quantiles
        )
        self.register_buffer(
            "_quantile_midpoints",
            ((self._quantiles[1:] + self._quantiles[:-1]) / 2)
            .unsqueeze(0)
            .unsqueeze(0),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_value_distribution(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
    ) -> Tensor:

        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x)

    @property
    def quantiles(self) -> Tensor:
        return self._quantiles

    @property
    def quantile_midpoints(self) -> Tensor:
        return self._quantile_midpoints

    @property
    def num_quantiles(self) -> int:
        return self._num_quantiles

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim


class DuelingQValueNetwork(QValueNetwork):
    """
    Dueling architecture consists of state architecture, value architecture,
    and advantage architecture.

    The architecture is as follows:
    state --> state_arch -----> value_arch --> value(s)-----------------------\
                                |                                              ---> add --> Q(s,a)
    action ------------concat-> advantage_arch --> advantage(s, a)--- -mean --/
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        value_hidden_dims: Optional[List[int]] = None,
        advantage_hidden_dims: Optional[List[int]] = None,
        state_hidden_dims: Optional[List[int]] = None,
    ) -> None:
        super(DuelingQValueNetwork, self).__init__()
        self._state_dim: int = state_dim
        self._action_dim: int = action_dim

        # state architecture
        self.state_arch = VanillaValueNetwork(
            input_dim=state_dim,
            hidden_dims=hidden_dims if state_hidden_dims is None else state_hidden_dims,
            output_dim=hidden_dims[-1],
        )

        # value architecture
        self.value_arch = VanillaValueNetwork(
            input_dim=hidden_dims[-1],  # same as state_arch output dim
            hidden_dims=hidden_dims if value_hidden_dims is None else value_hidden_dims,
            output_dim=output_dim,  # output_dim=1
        )

        # advantage architecture
        self.advantage_arch = VanillaValueNetwork(
            input_dim=hidden_dims[-1] + action_dim,  # state_arch out dim + action_dim
            hidden_dims=hidden_dims
            if advantage_hidden_dims is None
            else advantage_hidden_dims,
            output_dim=output_dim,  # output_dim=1
        )

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        assert state.shape[-1] == self.state_dim
        assert action.shape[-1] == self.action_dim

        # state feature architecture : state --> feature
        state_features = self.state_arch(
            state
        )  # shape: (?, state_dim); state_dim is the output dimension of state_arch mlp

        # value architecture : feature --> value
        state_value = self.value_arch(state_features)  # shape: (batch_size)

        # advantage architecture : [state feature, actions] --> advantage
        state_action_features = torch.cat(
            (state_features, action), dim=-1
        )  # shape: (?, state_dim + action_dim)

        advantage = self.advantage_arch(state_action_features)
        advantage_mean = torch.mean(
            advantage, dim=-2, keepdim=True
        )  # -2 is dimension denoting number of actions
        return state_value + advantage - advantage_mean

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            batch of states: (batch_size, state_dim)
            batch of actions: (batch_size, action_dim)
            (Optional) batch of available actions (one set of available actions per state):
                    (batch_size, available_action_space_size, action_dim)

            In DUELING_DQN, logic for use with td learning (deep_td_learning)
            a) when curr_available_actions_batch is None, we do a forward pass from Q network
               in this case, the action batch will be the batch of all available actions
               doing a forward pass with mean subtraction is correct

            b) when curr_available_actions_batch is not None,
               extend the state_batch tensor to include available actions,
               that is, state_batch: (batch_size, state_dim)
               --> (batch_size, available_action_space_size, state_dim)
               then, do a forward pass from Q network to calculate
               q-values for (state, all available actions),
               followed by q values for given (state, action) pair in the batch

        TODO: assumes a gym environment interface with fixed action space, change it with masking
        """

        if curr_available_actions_batch is None:
            return self.forward(state_batch, action_batch).view(-1)
        else:
            # calculate the q value of all available actions
            state_repeated_batch = extend_state_feature_by_available_action_space(
                state_batch=state_batch,
                curr_available_actions_batch=curr_available_actions_batch,
            )  # shape: (batch_size, available_action_space_size, state_dim)

            # collect Q values of a state and all available actions
            values_state_available_actions = self.forward(
                state_repeated_batch, curr_available_actions_batch
            )  # shape: (batch_size, available_action_space_size, action_dim)

            # gather only the q value of the action that we are interested in.
            action_idx = (
                torch.argmax(action_batch, dim=1).unsqueeze(-1).unsqueeze(-1)
            )  # one_hot to decimal

            # q value of (state, action) pair of interest
            state_action_values = torch.gather(
                values_state_available_actions, 1, action_idx
            ).view(
                -1
            )  # shape: (batch_size)
        return state_action_values


"""
One could make VanillaValueNetwork a special case of TwoTowerQValueNetwork by initializing
its linear layers to be an identity map and stopping gradients. However, this would be too complex.
"""


class TwoTowerNetwork(QValueNetwork):
    def __init__(
        self,
        state_input_dim: int,
        action_input_dim: int,
        state_output_dim: int,
        action_output_dim: int,
        state_hidden_dims: Optional[List[int]],
        action_hidden_dims: Optional[List[int]],
        hidden_dims: Optional[List[int]],
        output_dim: int = 1,
    ) -> None:

        super(TwoTowerNetwork, self).__init__()

        """
        Input: batch of state, batch of action. Output: batch of Q-values for (s,a) pairs
        The two tower architecture is as follows:
        state ----> state_feature
                            | concat ----> Q(s,a)
        action ----> action_feature
        """
        self._state_input_dim = state_input_dim
        self._action_input_dim = action_input_dim
        self._state_features = VanillaValueNetwork(
            input_dim=state_input_dim,
            hidden_dims=state_hidden_dims,
            output_dim=state_output_dim,
        )
        self._state_features.xavier_init()
        self._action_features = VanillaValueNetwork(
            input_dim=action_input_dim,
            hidden_dims=action_hidden_dims,
            output_dim=action_output_dim,
        )
        self._action_features.xavier_init()
        self._interaction_features = VanillaValueNetwork(
            input_dim=state_output_dim + action_output_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )
        self._interaction_features.xavier_init()

    def forward(self, state_action: Tensor) -> Tensor:
        state = state_action[..., : self._state_input_dim]
        action = state_action[..., self._state_input_dim :]
        output = self.get_q_values(state_batch=state, action_batch=action)
        return output

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        state_batch_features = self._state_features.forward(state_batch)
        """ this might need to be done in tensor_based_replay_buffer """
        action_batch_features = self._action_features.forward(
            action_batch.to(torch.get_default_dtype())
        )

        x = torch.cat([state_batch_features, action_batch_features], dim=-1)
        return self._interaction_features.forward(x).view(-1)  # (batch_size)

    @property
    def state_dim(self) -> int:
        return self._state_input_dim

    @property
    def action_dim(self) -> int:
        return self._action_input_dim


"""
With the same initialization parameters as the VanillaQValueNetwork, i.e. without
specifying state_output_dim and/or action_output_dim, we still add a linear layer to
extract state and/or action features.
"""


class TwoTowerQValueNetwork(TwoTowerNetwork):
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int = 1,
        state_output_dim: Optional[int] = None,
        action_output_dim: Optional[int] = None,
        state_hidden_dims: Optional[List[int]] = None,
        action_hidden_dims: Optional[List[int]] = None,
    ) -> None:

        super().__init__(
            state_input_dim=state_dim,
            action_input_dim=action_dim,
            state_output_dim=state_dim
            if state_output_dim is None
            else state_output_dim,
            action_output_dim=action_dim
            if action_output_dim is None
            else action_output_dim,
            state_hidden_dims=[] if state_hidden_dims is None else state_hidden_dims,
            action_hidden_dims=[] if action_hidden_dims is None else action_hidden_dims,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )


class EnsembleQValueNetwork(QValueNetwork):
    r"""A Q-value network that uses the `Ensemble` model."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: Optional[List[int]],
        output_dim: int,
        ensemble_size: int,
        prior_scale: float = 1.0,
    ) -> None:
        super(EnsembleQValueNetwork, self).__init__()
        self._state_dim = state_dim
        self._action_dim = action_dim
        self._model = Ensemble(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            ensemble_size=ensemble_size,
            prior_scale=prior_scale,
        )

    @property
    def ensemble_size(self) -> int:
        return self._model.ensemble_size

    def resample_epistemic_index(self) -> None:
        r"""Resamples the epistemic index of the underlying model."""
        self._model._resample_epistemic_index()

    def forward(
        self, x: Tensor, z: Optional[Tensor] = None, persistent: bool = False
    ) -> Tensor:
        return self._model(x, z=z, persistent=persistent)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
        z: Optional[Tensor] = None,
        persistent: bool = False,
    ) -> Tensor:
        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x, z=z, persistent=persistent).view(-1)

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

Classes

class CNNQValueNetwork (input_width: int, input_height: int, input_channels_count: int, kernel_sizes: List[int], output_channels_list: List[int], strides: List[int], paddings: List[int], action_dim: int, hidden_dims_fully_connected: Optional[List[int]] = None, output_dim: int = 1, use_batch_norm_conv: bool = False, use_batch_norm_fully_connected: bool = False)

A CNN version of state-action value (Q-value) network.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
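
A minimal usage sketch (illustrative only): the image size, conv hyperparameters, batch size, and one-hot action encoding below are arbitrary choices for this example, not values prescribed by the library.

import torch
from pearl.neural_networks.common.value_networks import CNNQValueNetwork

# Q-value network over 3x84x84 image observations and 4 one-hot encoded actions
# (all sizes are arbitrary choices for this sketch).
q_network = CNNQValueNetwork(
    input_width=84,
    input_height=84,
    input_channels_count=3,
    kernel_sizes=[8, 4],
    output_channels_list=[16, 32],
    strides=[4, 2],
    paddings=[0, 0],
    action_dim=4,
    hidden_dims_fully_connected=[64],
)

state_batch = torch.randn(10, 3, 84, 84)  # (batch_size, channels, 84, 84)
action_batch = torch.nn.functional.one_hot(
    torch.randint(0, 4, (10,)), num_classes=4
).float()  # (batch_size, action_dim)
q_values = q_network.get_q_values(state_batch, action_batch)  # shape: (batch_size,)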

Expand source code
class CNNQValueNetwork(VanillaCNN):
    """
    A CNN version of state-action value (Q-value) network.
    """

    def __init__(
        self,
        input_width: int,
        input_height: int,
        input_channels_count: int,
        kernel_sizes: List[int],
        output_channels_list: List[int],
        strides: List[int],
        paddings: List[int],
        action_dim: int,
        hidden_dims_fully_connected: Optional[List[int]] = None,
        output_dim: int = 1,
        use_batch_norm_conv: bool = False,
        use_batch_norm_fully_connected: bool = False,
    ) -> None:
        super(CNNQValueNetwork, self).__init__(
            input_width=input_width,
            input_height=input_height,
            input_channels_count=input_channels_count,
            kernel_sizes=kernel_sizes,
            output_channels_list=output_channels_list,
            strides=strides,
            paddings=paddings,
            hidden_dims_fully_connected=hidden_dims_fully_connected,
            use_batch_norm_conv=use_batch_norm_conv,
            use_batch_norm_fully_connected=use_batch_norm_fully_connected,
            output_dim=output_dim,
        )
        # we concatenate actions to state representations in the mlp block of the Q-value network
        self._mlp_input_dims: int = self.compute_output_dim_model_cnn() + action_dim
        self._model_fc: nn.Module = mlp_block(
            input_dim=self._mlp_input_dims,
            hidden_dims=self._hidden_dims_fully_connected,
            output_dim=self._output_dim,
            use_batch_norm=self._use_batch_norm_fully_connected,
        )
        self._action_dim = action_dim

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = state_batch.shape[0]
        state_representation_batch = self._model_cnn(state_batch)
        state_embedding_batch = torch.flatten(
            state_representation_batch, start_dim=1, end_dim=-1
        ).view(batch_size, -1)

        # concatenate actions to state representations and do a forward pass through the mlp_block
        x = torch.cat([state_embedding_batch, action_batch], dim=-1)
        q_values_batch = self._model_fc(x)
        return q_values_batch.view(-1)

    @property
    def action_dim(self) -> int:
        return self._action_dim

Ancestors

Instance variables

var action_dim : int
Expand source code
@property
def action_dim(self) -> int:
    return self._action_dim

Methods

def get_q_values(self, state_batch: torch.Tensor, action_batch: torch.Tensor, curr_available_actions_batch: Optional[torch.Tensor] = None) ‑> torch.Tensor
Expand source code
def get_q_values(
    self,
    state_batch: Tensor,
    action_batch: Tensor,
    curr_available_actions_batch: Optional[Tensor] = None,
) -> Tensor:
    batch_size = state_batch.shape[0]
    state_representation_batch = self._model_cnn(state_batch)
    state_embedding_batch = torch.flatten(
        state_representation_batch, start_dim=1, end_dim=-1
    ).view(batch_size, -1)

    # concatenate actions to state representations and do a forward pass through the mlp_block
    x = torch.cat([state_embedding_batch, action_batch], dim=-1)
    q_values_batch = self._model_fc(x)
    return q_values_batch.view(-1)

Inherited members

class DuelingQValueNetwork (state_dim: int, action_dim: int, hidden_dims: List[int], output_dim: int, value_hidden_dims: Optional[List[int]] = None, advantage_hidden_dims: Optional[List[int]] = None, state_hidden_dims: Optional[List[int]] = None)

Dueling architecture consists of state architecture, value architecture, and advantage architecture.

The architecture is as follows:

state --> state_arch -----> value_arch --> value(s)-----------------------\
                            |                                              ---> add --> Q(s,a)
action ------------concat-> advantage_arch --> advantage(s, a)--- -mean --/

Initializes internal Module state, shared by both nn.Module and ScriptModule.
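
A minimal usage sketch (illustrative; the sizes and one-hot encoding are arbitrary, and the available-actions path assumes extend_state_feature_by_available_action_space repeats each state once per available action, as the comments in get_q_values describe):

import torch
from pearl.neural_networks.common.value_networks import DuelingQValueNetwork

state_dim, action_dim, batch_size = 16, 4, 8  # arbitrary sizes for this sketch
q_network = DuelingQValueNetwork(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dims=[64, 64],
    output_dim=1,
)

state_batch = torch.randn(batch_size, state_dim)
# one-hot encoded action taken in each state
action_batch = torch.nn.functional.one_hot(
    torch.randint(0, action_dim, (batch_size,)), num_classes=action_dim
).float()
# the same set of one-hot available actions for every state
available_actions = torch.eye(action_dim).repeat(batch_size, 1, 1)

q_values = q_network.get_q_values(
    state_batch, action_batch, curr_available_actions_batch=available_actions
)  # shape: (batch_size,)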

Expand source code
class DuelingQValueNetwork(QValueNetwork):
    """
    Dueling architecture consists of state architecture, value architecture,
    and advantage architecture.

    The architecture is as follows:
    state --> state_arch -----> value_arch --> value(s)-----------------------\
                                |                                              ---> add --> Q(s,a)
    action ------------concat-> advantage_arch --> advantage(s, a)--- -mean --/
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        value_hidden_dims: Optional[List[int]] = None,
        advantage_hidden_dims: Optional[List[int]] = None,
        state_hidden_dims: Optional[List[int]] = None,
    ) -> None:
        super(DuelingQValueNetwork, self).__init__()
        self._state_dim: int = state_dim
        self._action_dim: int = action_dim

        # state architecture
        self.state_arch = VanillaValueNetwork(
            input_dim=state_dim,
            hidden_dims=hidden_dims if state_hidden_dims is None else state_hidden_dims,
            output_dim=hidden_dims[-1],
        )

        # value architecture
        self.value_arch = VanillaValueNetwork(
            input_dim=hidden_dims[-1],  # same as state_arch output dim
            hidden_dims=hidden_dims if value_hidden_dims is None else value_hidden_dims,
            output_dim=output_dim,  # output_dim=1
        )

        # advantage architecture
        self.advantage_arch = VanillaValueNetwork(
            input_dim=hidden_dims[-1] + action_dim,  # state_arch out dim + action_dim
            hidden_dims=hidden_dims
            if advantage_hidden_dims is None
            else advantage_hidden_dims,
            output_dim=output_dim,  # output_dim=1
        )

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        assert state.shape[-1] == self.state_dim
        assert action.shape[-1] == self.action_dim

        # state feature architecture : state --> feature
        state_features = self.state_arch(
            state
        )  # shape: (?, state_dim); state_dim is the output dimension of state_arch mlp

        # value architecture : feature --> value
        state_value = self.value_arch(state_features)  # shape: (batch_size)

        # advantage architecture : [state feature, actions] --> advantage
        state_action_features = torch.cat(
            (state_features, action), dim=-1
        )  # shape: (?, state_dim + action_dim)

        advantage = self.advantage_arch(state_action_features)
        advantage_mean = torch.mean(
            advantage, dim=-2, keepdim=True
        )  # -2 is dimension denoting number of actions
        return state_value + advantage - advantage_mean

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            batch of states: (batch_size, state_dim)
            batch of actions: (batch_size, action_dim)
            (Optional) batch of available actions (one set of available actions per state):
                    (batch_size, available_action_space_size, action_dim)

            In DUELING_DQN, logic for use with td learning (deep_td_learning)
            a) when curr_available_actions_batch is None, we do a forward pass from Q network
               in this case, the action batch will be the batch of all available actions
               doing a forward pass with mean subtraction is correct

            b) when curr_available_actions_batch is not None,
               extend the state_batch tensor to include available actions,
               that is, state_batch: (batch_size, state_dim)
               --> (batch_size, available_action_space_size, state_dim)
               then, do a forward pass from Q network to calculate
               q-values for (state, all available actions),
               followed by q values for given (state, action) pair in the batch

        TODO: assumes a gym environment interface with fixed action space, change it with masking
        """

        if curr_available_actions_batch is None:
            return self.forward(state_batch, action_batch).view(-1)
        else:
            # calculate the q value of all available actions
            state_repeated_batch = extend_state_feature_by_available_action_space(
                state_batch=state_batch,
                curr_available_actions_batch=curr_available_actions_batch,
            )  # shape: (batch_size, available_action_space_size, state_dim)

            # collect Q values of a state and all available actions
            values_state_available_actions = self.forward(
                state_repeated_batch, curr_available_actions_batch
            )  # shape: (batch_size, available_action_space_size, action_dim)

            # gather only the q value of the action that we are interested in.
            action_idx = (
                torch.argmax(action_batch, dim=1).unsqueeze(-1).unsqueeze(-1)
            )  # one_hot to decimal

            # q value of (state, action) pair of interest
            state_action_values = torch.gather(
                values_state_available_actions, 1, action_idx
            ).view(
                -1
            )  # shape: (batch_size)
        return state_action_values

Ancestors

Methods

def forward(self, state: torch.Tensor, action: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, state: Tensor, action: Tensor) -> Tensor:
    assert state.shape[-1] == self.state_dim
    assert action.shape[-1] == self.action_dim

    # state feature architecture : state --> feature
    state_features = self.state_arch(
        state
    )  # shape: (?, state_dim); state_dim is the output dimension of state_arch mlp

    # value architecture : feature --> value
    state_value = self.value_arch(state_features)  # shape: (batch_size)

    # advantage architecture : [state feature, actions] --> advantage
    state_action_features = torch.cat(
        (state_features, action), dim=-1
    )  # shape: (?, state_dim + action_dim)

    advantage = self.advantage_arch(state_action_features)
    advantage_mean = torch.mean(
        advantage, dim=-2, keepdim=True
    )  # -2 is dimension denoting number of actions
    return state_value + advantage - advantage_mean
def get_q_values(self, state_batch: torch.Tensor, action_batch: torch.Tensor, curr_available_actions_batch: Optional[torch.Tensor] = None) ‑> torch.Tensor

Args

state_batch: batch of states, shape (batch_size, state_dim)
action_batch: batch of actions, shape (batch_size, action_dim)
curr_available_actions_batch (optional): batch of available actions, one set of available actions per state, shape (batch_size, available_action_space_size, action_dim)

In DUELING_DQN, the logic for use with TD learning (deep_td_learning) is:

a) when curr_available_actions_batch is None, we do a forward pass through the Q network; in this case, the action batch will be the batch of all available actions, and doing a forward pass with mean subtraction is correct.

b) when curr_available_actions_batch is not None, extend the state_batch tensor to include available actions, that is, state_batch: (batch_size, state_dim) --> (batch_size, available_action_space_size, state_dim); then do a forward pass through the Q network to calculate q-values for (state, all available actions), followed by q-values for the given (state, action) pair in the batch.

TODO: assumes a gym environment interface with a fixed action space; change it to use masking.

Expand source code
def get_q_values(
    self,
    state_batch: Tensor,
    action_batch: Tensor,
    curr_available_actions_batch: Optional[Tensor] = None,
) -> Tensor:
    """
    Args:
        batch of states: (batch_size, state_dim)
        batch of actions: (batch_size, action_dim)
        (Optional) batch of available actions (one set of available actions per state):
                (batch_size, available_action_space_size, action_dim)

        In DUELING_DQN, logic for use with td learning (deep_td_learning)
        a) when curr_available_actions_batch is None, we do a forward pass from Q network
           in this case, the action batch will be the batch of all available actions
           doing a forward pass with mean subtraction is correct

        b) when curr_available_actions_batch is not None,
           extend the state_batch tensor to include available actions,
           that is, state_batch: (batch_size, state_dim)
           --> (batch_size, available_action_space_size, state_dim)
           then, do a forward pass from Q network to calculate
           q-values for (state, all available actions),
           followed by q values for given (state, action) pair in the batch

    TODO: assumes a gym environment interface with fixed action space, change it with masking
    """

    if curr_available_actions_batch is None:
        return self.forward(state_batch, action_batch).view(-1)
    else:
        # calculate the q value of all available actions
        state_repeated_batch = extend_state_feature_by_available_action_space(
            state_batch=state_batch,
            curr_available_actions_batch=curr_available_actions_batch,
        )  # shape: (batch_size, available_action_space_size, state_dim)

        # collect Q values of a state and all available actions
        values_state_available_actions = self.forward(
            state_repeated_batch, curr_available_actions_batch
        )  # shape: (batch_size, available_action_space_size, action_dim)

        # gather only the q value of the action that we are interested in.
        action_idx = (
            torch.argmax(action_batch, dim=1).unsqueeze(-1).unsqueeze(-1)
        )  # one_hot to decimal

        # q value of (state, action) pair of interest
        state_action_values = torch.gather(
            values_state_available_actions, 1, action_idx
        ).view(
            -1
        )  # shape: (batch_size)
    return state_action_values

Inherited members

class EnsembleQValueNetwork (state_dim: int, action_dim: int, hidden_dims: Optional[List[int]], output_dim: int, ensemble_size: int, prior_scale: float = 1.0)

A Q-value network that uses the Ensemble model.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
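
An illustrative sketch with arbitrary sizes; it assumes the underlying Ensemble model falls back to its stored epistemic index when z is None, as the defaults in get_q_values suggest:

import torch
from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork

q_network = EnsembleQValueNetwork(
    state_dim=8,          # arbitrary sizes for this sketch
    action_dim=2,
    hidden_dims=[32, 32],
    output_dim=1,
    ensemble_size=5,
)

state_batch = torch.randn(4, 8)
action_batch = torch.randn(4, 2)

q_network.resample_epistemic_index()  # switch to a freshly sampled ensemble member
q_values = q_network.get_q_values(state_batch, action_batch)  # shape: (batch_size,)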

Expand source code
class EnsembleQValueNetwork(QValueNetwork):
    r"""A Q-value network that uses the `Ensemble` model."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: Optional[List[int]],
        output_dim: int,
        ensemble_size: int,
        prior_scale: float = 1.0,
    ) -> None:
        super(EnsembleQValueNetwork, self).__init__()
        self._state_dim = state_dim
        self._action_dim = action_dim
        self._model = Ensemble(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            ensemble_size=ensemble_size,
            prior_scale=prior_scale,
        )

    @property
    def ensemble_size(self) -> int:
        return self._model.ensemble_size

    def resample_epistemic_index(self) -> None:
        r"""Resamples the epistemic index of the underlying model."""
        self._model._resample_epistemic_index()

    def forward(
        self, x: Tensor, z: Optional[Tensor] = None, persistent: bool = False
    ) -> Tensor:
        return self._model(x, z=z, persistent=persistent)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
        z: Optional[Tensor] = None,
        persistent: bool = False,
    ) -> Tensor:
        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x, z=z, persistent=persistent).view(-1)

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

Ancestors

Instance variables

var ensemble_size : int
Expand source code
@property
def ensemble_size(self) -> int:
    return self._model.ensemble_size

Methods

def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None, persistent: bool = False) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(
    self, x: Tensor, z: Optional[Tensor] = None, persistent: bool = False
) -> Tensor:
    return self._model(x, z=z, persistent=persistent)
def resample_epistemic_index(self) ‑> None

Resamples the epistemic index of the underlying model.

Expand source code
def resample_epistemic_index(self) -> None:
    r"""Resamples the epistemic index of the underlying model."""
    self._model._resample_epistemic_index()

Inherited members

class QuantileQValueNetwork (state_dim: int, action_dim: int, hidden_dims: List[int], num_quantiles: int, use_layer_norm: bool = False)

A quantile version of state-action value (Q-value) network. For each (state, action) input pair, it returns theta(s,a), the locations of the quantiles that parameterize the Q-value distribution.

See the parameterization in the QR-DQN paper (https://arxiv.org/pdf/1710.10044.pdf) for more details.

Assume N is the number of quantiles.
1) For this parameterization, the quantile fractions are fixed and uniformly spaced (1/N each), while the quantile locations, theta(s,a), are learned.
2) The return distribution is represented as a uniform mixture of Dirac deltas at the quantile locations: Z(s, a) = (1/N) * sum_{i=1}^N delta(theta_i(s, a)), where theta_1(s,a), ..., theta_N(s,a), the quantile locations, are outputs of the QuantileQValueNetwork.

Args

num_quantiles
the number of quantiles N, used to approximate the value distribution.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
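
An illustrative sketch with arbitrary sizes, showing that the network returns quantile locations rather than a single Q-value:

import torch
from pearl.neural_networks.common.value_networks import QuantileQValueNetwork

q_network = QuantileQValueNetwork(
    state_dim=8,          # arbitrary sizes for this sketch
    action_dim=2,
    hidden_dims=[64, 64],
    num_quantiles=32,
)

state_batch = torch.randn(16, 8)
action_batch = torch.randn(16, 2)

# quantile locations theta_i(s, a); shape: (batch_size, num_quantiles)
quantile_locations = q_network.get_q_value_distribution(state_batch, action_batch)
# since the quantile fractions are uniform, the mean of the locations is a point estimate of Q(s, a)
q_values = quantile_locations.mean(dim=-1)  # shape: (batch_size,)
# q_network.quantile_midpoints has shape (1, 1, num_quantiles), for use in a quantile regression loss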

Expand source code
class QuantileQValueNetwork(DistributionalQValueNetwork):
    """
    A quantile version of state-action value (Q-value) network.
    For each (state, action) input pair,
    it returns theta(s,a), the locations of the quantiles that parameterize the Q-value distribution.

    See the parameterization in the QR-DQN paper (https://arxiv.org/pdf/1710.10044.pdf) for more details.

    Assume N is the number of quantiles.
    1) For this parameterization, the quantile fractions are fixed and uniformly spaced (1/N each),
       while the quantile locations, theta(s,a), are learned.
    2) The return distribution is represented as a uniform mixture of Dirac deltas
       at the quantile locations: Z(s, a) = (1/N) * sum_{i=1}^N delta(theta_i(s, a)),
       where theta_1(s,a), ..., theta_N(s,a), the quantile locations,
       are outputs of the QuantileQValueNetwork.

    Args:
        num_quantiles: the number of quantiles N, used to approximate the value distribution.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        num_quantiles: int,
        use_layer_norm: bool = False,
    ) -> None:
        super(QuantileQValueNetwork, self).__init__()

        self._model: nn.Module = mlp_block(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=num_quantiles,
            use_layer_norm=use_layer_norm,
        )

        self._state_dim: int = state_dim
        self._action_dim: int = action_dim
        self._num_quantiles: int = num_quantiles
        self.register_buffer(
            "_quantiles", torch.arange(0, self._num_quantiles + 1) / self._num_quantiles
        )
        self.register_buffer(
            "_quantile_midpoints",
            ((self._quantiles[1:] + self._quantiles[:-1]) / 2)
            .unsqueeze(0)
            .unsqueeze(0),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_value_distribution(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
    ) -> Tensor:

        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x)

    @property
    def quantiles(self) -> Tensor:
        return self._quantiles

    @property
    def quantile_midpoints(self) -> Tensor:
        return self._quantile_midpoints

    @property
    def num_quantiles(self) -> int:
        return self._num_quantiles

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

Ancestors

Methods

def forward(self, x: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x: Tensor) -> Tensor:
    return self._model(x)

Inherited members

class TwoTowerNetwork (state_input_dim: int, action_input_dim: int, state_output_dim: int, action_output_dim: int, state_hidden_dims: Optional[List[int]], action_hidden_dims: Optional[List[int]], hidden_dims: Optional[List[int]], output_dim: int = 1)

An interface for state-action value (Q-value) estimators (typically, neural networks). These are value neural networks with a special method for computing the Q-value for a state-action pair.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TwoTowerNetwork(QValueNetwork):
    def __init__(
        self,
        state_input_dim: int,
        action_input_dim: int,
        state_output_dim: int,
        action_output_dim: int,
        state_hidden_dims: Optional[List[int]],
        action_hidden_dims: Optional[List[int]],
        hidden_dims: Optional[List[int]],
        output_dim: int = 1,
    ) -> None:

        super(TwoTowerNetwork, self).__init__()

        """
        Input: batch of state, batch of action. Output: batch of Q-values for (s,a) pairs
        The two tower architecture is as follows:
        state ----> state_feature
                            | concat ----> Q(s,a)
        action ----> action_feature
        """
        self._state_input_dim = state_input_dim
        self._action_input_dim = action_input_dim
        self._state_features = VanillaValueNetwork(
            input_dim=state_input_dim,
            hidden_dims=state_hidden_dims,
            output_dim=state_output_dim,
        )
        self._state_features.xavier_init()
        self._action_features = VanillaValueNetwork(
            input_dim=action_input_dim,
            hidden_dims=action_hidden_dims,
            output_dim=action_output_dim,
        )
        self._action_features.xavier_init()
        self._interaction_features = VanillaValueNetwork(
            input_dim=state_output_dim + action_output_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )
        self._interaction_features.xavier_init()

    def forward(self, state_action: Tensor) -> Tensor:
        state = state_action[..., : self._state_input_dim]
        action = state_action[..., self._state_input_dim :]
        output = self.get_q_values(state_batch=state, action_batch=action)
        return output

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        state_batch_features = self._state_features.forward(state_batch)
        """ this might need to be done in tensor_based_replay_buffer """
        action_batch_features = self._action_features.forward(
            action_batch.to(torch.get_default_dtype())
        )

        x = torch.cat([state_batch_features, action_batch_features], dim=-1)
        return self._interaction_features.forward(x).view(-1)  # (batch_size)

    @property
    def state_dim(self) -> int:
        return self._state_input_dim

    @property
    def action_dim(self) -> int:
        return self._action_input_dim

Ancestors

Subclasses

Methods

def forward(self, state_action: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, state_action: Tensor) -> Tensor:
    state = state_action[..., : self._state_input_dim]
    action = state_action[..., self._state_input_dim :]
    output = self.get_q_values(state_batch=state, action_batch=action)
    return output

Inherited members

class TwoTowerQValueNetwork (state_dim: int, action_dim: int, hidden_dims: List[int], output_dim: int = 1, state_output_dim: Optional[int] = None, action_output_dim: Optional[int] = None, state_hidden_dims: Optional[List[int]] = None, action_hidden_dims: Optional[List[int]] = None)

An interface for state-action value (Q-value) estimators (typically, neural networks). These are value neural networks with a special method for computing the Q-value for a state-action pair.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
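
An illustrative sketch with arbitrary sizes, showing the optional feature dimensions:

import torch
from pearl.neural_networks.common.value_networks import TwoTowerQValueNetwork

q_network = TwoTowerQValueNetwork(
    state_dim=10,            # arbitrary sizes for this sketch
    action_dim=3,
    hidden_dims=[64, 64],    # hidden dims of the interaction (joint) mlp
    state_output_dim=16,     # dimension of the learned state features (defaults to state_dim)
    action_output_dim=8,     # dimension of the learned action features (defaults to action_dim)
)

state_batch = torch.randn(32, 10)
action_batch = torch.randn(32, 3)
q_values = q_network.get_q_values(state_batch, action_batch)  # shape: (batch_size,)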

Expand source code
class TwoTowerQValueNetwork(TwoTowerNetwork):
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int = 1,
        state_output_dim: Optional[int] = None,
        action_output_dim: Optional[int] = None,
        state_hidden_dims: Optional[List[int]] = None,
        action_hidden_dims: Optional[List[int]] = None,
    ) -> None:

        super().__init__(
            state_input_dim=state_dim,
            action_input_dim=action_dim,
            state_output_dim=state_dim
            if state_output_dim is None
            else state_output_dim,
            action_output_dim=action_dim
            if action_output_dim is None
            else action_output_dim,
            state_hidden_dims=[] if state_hidden_dims is None else state_hidden_dims,
            action_hidden_dims=[] if action_hidden_dims is None else action_hidden_dims,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )

Ancestors

Inherited members

class ValueNetwork (*args, **kwargs)

An interface for value neural networks. It does not add any required methods to those already present in its super classes. Its purpose instead is just to serve as an umbrella type for all value networks.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class ValueNetwork(nn.Module, ABC):
    """
    An interface for value neural networks.
    It does not add any required methods to those already present in
    its super classes.
    Its purpose instead is just to serve as an umbrella type for all value networks.
    """

Ancestors

  • torch.nn.modules.module.Module
  • abc.ABC

Subclasses

class VanillaCNN (input_width: int, input_height: int, input_channels_count: int, kernel_sizes: List[int], output_channels_list: List[int], strides: List[int], paddings: List[int], hidden_dims_fully_connected: Optional[List[int]] = None, use_batch_norm_conv: bool = False, use_batch_norm_fully_connected: bool = False, output_dim: int = 1)

Vanilla CNN with a convolutional block followed by an mlp block.

Args

input_width
width of the input
input_height
height of the input
input_channels_count
number of input channels
kernel_sizes
list of kernel sizes for the convolutional layers
output_channels_list
list of number of output channels for each convolutional layer
strides
list of strides for each layer
paddings
list of paddings for each layer
hidden_dims_fully_connected
a list of dimensions of the hidden layers in the mlp
use_batch_norm_conv
whether to use batch_norm in the convolutional layers
use_batch_norm_fully_connected
whether to use batch_norm in the fully connected layers
output_dim
dimension of the output layer

Returns

An nn.Sequential module consisting of a convolutional block followed by an mlp.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
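
An illustrative sketch with arbitrary input size and conv hyperparameters:

import torch
from pearl.neural_networks.common.value_networks import VanillaCNN

value_net = VanillaCNN(
    input_width=28,                    # arbitrary sizes for this sketch
    input_height=28,
    input_channels_count=1,
    kernel_sizes=[3, 3],
    output_channels_list=[8, 16],
    strides=[1, 1],
    paddings=[1, 1],
    hidden_dims_fully_connected=[32],
    output_dim=1,
)

x = torch.randn(4, 1, 28, 28)  # (batch_size, channels, 28, 28)
values = value_net(x)          # shape: (batch_size, output_dim) = (4, 1)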

Expand source code
class VanillaCNN(ValueNetwork):
    """
    Vanilla CNN with a convolutional block followed by an mlp block.
    Args:
        input_width: width of the input
        input_height: height of the input
        input_channels_count: number of input channels
        kernel_sizes: list of kernel sizes for the convolutional layers
        output_channels_list: list of number of output channels for each convolutional layer
        strides: list of strides for each layer
        paddings: list of paddings for each layer
        hidden_dims_fully_connected: a list of dimensions of the hidden layers in the mlp
        use_batch_norm_conv: whether to use batch_norm in the convolutional layers
        use_batch_norm_fully_connected: whether to use batch_norm in the fully connected layers
        output_dim: dimension of the output layer
    Returns:
        An nn.Sequential module consisting of a convolutional block followed by an mlp.
    """

    def __init__(
        self,
        input_width: int,
        input_height: int,
        input_channels_count: int,
        kernel_sizes: List[int],
        output_channels_list: List[int],
        strides: List[int],
        paddings: List[int],
        hidden_dims_fully_connected: Optional[
            List[int]
        ] = None,  # hidden dims for fully connected layers
        use_batch_norm_conv: bool = False,
        use_batch_norm_fully_connected: bool = False,
        output_dim: int = 1,  # dimension of the final output
    ) -> None:

        assert (
            len(kernel_sizes)
            == len(output_channels_list)
            == len(strides)
            == len(paddings)
        )
        super(VanillaCNN, self).__init__()

        self._input_channels = input_channels_count
        self._input_height = input_height
        self._input_width = input_width
        self._output_channels = output_channels_list
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._paddings = paddings
        if hidden_dims_fully_connected is None:
            self._hidden_dims_fully_connected: List[int] = []
        else:
            self._hidden_dims_fully_connected: List[int] = hidden_dims_fully_connected

        self._use_batch_norm_conv = use_batch_norm_conv
        self._use_batch_norm_fully_connected = use_batch_norm_fully_connected
        self._output_dim = output_dim

        self._model_cnn: nn.Module = conv_block(
            input_channels_count=self._input_channels,
            output_channels_list=self._output_channels,
            kernel_sizes=self._kernel_sizes,
            strides=self._strides,
            paddings=self._paddings,
            use_batch_norm=self._use_batch_norm_conv,
        )

        self._mlp_input_dims: int = self.compute_output_dim_model_cnn()
        self._model_fc: nn.Module = mlp_block(
            input_dim=self._mlp_input_dims,
            hidden_dims=self._hidden_dims_fully_connected,
            output_dim=self._output_dim,
            use_batch_norm=self._use_batch_norm_fully_connected,
        )

    def compute_output_dim_model_cnn(self) -> int:
        dummy_input = torch.zeros(
            1, self._input_channels, self._input_width, self._input_height
        )
        dummy_output_flattened = torch.flatten(
            self._model_cnn(dummy_input), start_dim=1, end_dim=-1
        )
        return dummy_output_flattened.shape[1]

    def forward(self, x: Tensor) -> Tensor:
        out_cnn = self._model_cnn(x)
        out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
        out_fc = self._model_fc(out_flattened)
        return out_fc

Ancestors

Subclasses

Methods

def compute_output_dim_model_cnn(self) ‑> int
Expand source code
def compute_output_dim_model_cnn(self) -> int:
    dummy_input = torch.zeros(
        1, self._input_channels, self._input_width, self._input_height
    )
    dummy_output_flattened = torch.flatten(
        self._model_cnn(dummy_input), start_dim=1, end_dim=-1
    )
    return dummy_output_flattened.shape[1]
def forward(self, x: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x: Tensor) -> Tensor:
    out_cnn = self._model_cnn(x)
    out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
    out_fc = self._model_fc(out_flattened)
    return out_fc
class VanillaQValueNetwork (state_dim: int, action_dim: int, hidden_dims: List[int], output_dim: int, use_layer_norm: bool = False)

A vanilla version of state-action value (Q-value) network. It leverages the vanilla implementation of value networks by using the state-action pair as the input for the value network.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
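
An illustrative sketch with arbitrary sizes:

import torch
from pearl.neural_networks.common.value_networks import VanillaQValueNetwork

q_network = VanillaQValueNetwork(
    state_dim=6,           # arbitrary sizes for this sketch
    action_dim=2,
    hidden_dims=[64, 64],
    output_dim=1,
)

state_batch = torch.randn(32, 6)
action_batch = torch.randn(32, 2)
q_values = q_network.get_q_values(state_batch, action_batch)  # shape: (batch_size,)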

Expand source code
class VanillaQValueNetwork(QValueNetwork):
    """
    A vanilla version of state-action value (Q-value) network.
    It leverages the vanilla implementation of value networks by
    using the state-action pair as the input for the value network.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        use_layer_norm: bool = False,
    ) -> None:
        super(VanillaQValueNetwork, self).__init__()
        self._state_dim: int = state_dim
        self._action_dim: int = action_dim
        self._model: nn.Module = mlp_block(
            input_dim=state_dim + action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            use_layer_norm=use_layer_norm,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    def get_q_values(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        curr_available_actions_batch: Optional[Tensor] = None,
    ) -> Tensor:
        x = torch.cat([state_batch, action_batch], dim=-1)
        return self.forward(x).view(-1)

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def action_dim(self) -> int:
        return self._action_dim

Ancestors

Methods

def forward(self, x: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x: Tensor) -> Tensor:
    return self._model(x)

Inherited members

class VanillaValueNetwork (input_dim: int, hidden_dims: Optional[List[int]], output_dim: int = 1, **kwargs: Any)

An interface for value neural networks. It does not add any required methods to those already present in its super classes. Its purpose instead is just to serve as an umbrella type for all value networks.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
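
An illustrative sketch with arbitrary dimensions:

import torch
from pearl.neural_networks.common.value_networks import VanillaValueNetwork

value_net = VanillaValueNetwork(input_dim=4, hidden_dims=[32, 32], output_dim=1)
value_net.xavier_init()  # optionally re-initialize the linear layers with Xavier normal

states = torch.randn(8, 4)   # (batch_size, input_dim)
values = value_net(states)   # shape: (batch_size, output_dim) = (8, 1)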

Expand source code
class VanillaValueNetwork(ValueNetwork):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]],
        output_dim: int = 1,
        **kwargs: Any,
    ) -> None:
        super(VanillaValueNetwork, self).__init__()
        self._model: nn.Module = mlp_block(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self._model(x)

    # default initialization in linear and conv layers of an nn.Sequential model is Kaiming
    def xavier_init(self) -> None:
        for layer in self._model:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

Ancestors

Methods

def forward(self, x: torch.Tensor) ‑> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the :class:Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def forward(self, x: Tensor) -> Tensor:
    return self._model(x)
def xavier_init(self) ‑> None
Expand source code
def xavier_init(self) -> None:
    for layer in self._model:
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)