Module pearl.neural_networks.sequential_decision_making.twin_critic

Expand source code
import inspect
from typing import Callable, Iterable, Tuple, Type

import torch
import torch.nn as nn
from pearl.neural_networks.common.value_networks import (
    QValueNetwork,
    VanillaQValueNetwork,
)


class TwinCritic(torch.nn.Module):
    """
    Wrapper holding two critic networks, used to reduce overestimation bias in
    critic estimation.

    Both critics are built from the same `network_type` with identical
    constructor arguments, and the same `init_fn` is then applied to both of
    them. Their parameters therefore only differ through randomness in the
    network construction and/or in `init_fn`.

    NOTE: For more than two critics, the standard way is to use nn.ModuleList()

    Args:
        state_dim (int): dimension of the state input, forwarded to `network_type`.
        action_dim (int): dimension of the action input, forwarded to `network_type`.
        hidden_dims (Iterable[int]): hidden layer sizes, forwarded to `network_type`.
            Any iterable is accepted (including one-shot iterators); it is
            materialized into a list once and that list is given to both critics.
        init_fn (Callable[[nn.Module], None]): initialization function applied via
            `nn.Module.apply`, i.e. called on every submodule of both critics
            (and on the containing `nn.ModuleList` itself).
        network_type (Type[QValueNetwork]): concrete Q-value network class to
            instantiate for each critic.
        output_dim (int): output dimension forwarded to `network_type`.

    Raises:
        ValueError: if `network_type` is an abstract class.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: Iterable[int],
        init_fn: Callable[[nn.Module], None],
        network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        output_dim: int = 1,
    ) -> None:
        super().__init__()

        if inspect.isabstract(network_type):
            raise ValueError("network_type must not be abstract")

        # `hidden_dims` is only promised to be an Iterable. Materialize it once so
        # that a one-shot iterator (e.g. a generator) is not exhausted while
        # building the first critic, which would leave the second critic with an
        # empty list of hidden layers.
        hidden_dims = list(hidden_dims)

        # pyre-ignore[45]:
        # Pyre does not know that `network_type` is asserted to be concrete
        self._critic_1: QValueNetwork = network_type(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )

        # pyre-ignore[45]:
        # Pyre does not know that `network_type` is asserted to be concrete
        self._critic_2: QValueNetwork = network_type(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )

        # nn.ModuleList helps manage the networks
        # (initialization, parameter update etc.) efficiently.
        # `apply` walks every submodule of both critics and calls `init_fn` on it.
        self._critic_networks_combined = nn.ModuleList([self._critic_1, self._critic_2])
        self._critic_networks_combined.apply(init_fn)

    def get_q_values(
        self,
        state_batch: torch.Tensor,
        action_batch: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both critics on the same batch of (state, action) pairs.

        Args:
            state_batch (torch.Tensor): a batch of states with shape (batch_size, state_dim)
            action_batch (torch.Tensor): a batch of actions with shape (batch_size, action_dim)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Q-values of (state, action) pairs with shape
            (batch_size); the first element comes from critic 1, the second from critic 2.
        """
        critic_1_values = self._critic_1.get_q_values(state_batch, action_batch)
        critic_2_values = self._critic_2.get_q_values(state_batch, action_batch)
        return critic_1_values, critic_2_values

Classes

class TwinCritic (state_dim: int, action_dim: int, hidden_dims: Iterable[int], init_fn: Callable[[torch.nn.modules.module.Module], None], network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, output_dim: int = 1)

This is a wrapper for using two critic networks to reduce overestimation bias in critic estimation. Both critics are built from the same network type and initialized by the same given initialization function.

NOTE: For more than two critics, the standard way is to use nn.ModuleList()

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class TwinCritic(torch.nn.Module):
    """
    Wrapper holding two critic networks, used to reduce overestimation bias in
    critic estimation.

    Both critics are built from the same `network_type` with identical
    constructor arguments, and the same `init_fn` is then applied to both of
    them. Their parameters therefore only differ through randomness in the
    network construction and/or in `init_fn`.

    NOTE: For more than two critics, the standard way is to use nn.ModuleList()

    Args:
        state_dim (int): dimension of the state input, forwarded to `network_type`.
        action_dim (int): dimension of the action input, forwarded to `network_type`.
        hidden_dims (Iterable[int]): hidden layer sizes, forwarded to `network_type`.
            Any iterable is accepted (including one-shot iterators); it is
            materialized into a list once and that list is given to both critics.
        init_fn (Callable[[nn.Module], None]): initialization function applied via
            `nn.Module.apply`, i.e. called on every submodule of both critics
            (and on the containing `nn.ModuleList` itself).
        network_type (Type[QValueNetwork]): concrete Q-value network class to
            instantiate for each critic.
        output_dim (int): output dimension forwarded to `network_type`.

    Raises:
        ValueError: if `network_type` is an abstract class.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dims: Iterable[int],
        init_fn: Callable[[nn.Module], None],
        network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        output_dim: int = 1,
    ) -> None:
        super().__init__()

        if inspect.isabstract(network_type):
            raise ValueError("network_type must not be abstract")

        # `hidden_dims` is only promised to be an Iterable. Materialize it once so
        # that a one-shot iterator (e.g. a generator) is not exhausted while
        # building the first critic, which would leave the second critic with an
        # empty list of hidden layers.
        hidden_dims = list(hidden_dims)

        # pyre-ignore[45]:
        # Pyre does not know that `network_type` is asserted to be concrete
        self._critic_1: QValueNetwork = network_type(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )

        # pyre-ignore[45]:
        # Pyre does not know that `network_type` is asserted to be concrete
        self._critic_2: QValueNetwork = network_type(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=hidden_dims,
            output_dim=output_dim,
        )

        # nn.ModuleList helps manage the networks
        # (initialization, parameter update etc.) efficiently.
        # `apply` walks every submodule of both critics and calls `init_fn` on it.
        self._critic_networks_combined = nn.ModuleList([self._critic_1, self._critic_2])
        self._critic_networks_combined.apply(init_fn)

    def get_q_values(
        self,
        state_batch: torch.Tensor,
        action_batch: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both critics on the same batch of (state, action) pairs.

        Args:
            state_batch (torch.Tensor): a batch of states with shape (batch_size, state_dim)
            action_batch (torch.Tensor): a batch of actions with shape (batch_size, action_dim)
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Q-values of (state, action) pairs with shape
            (batch_size); the first element comes from critic 1, the second from critic 2.
        """
        critic_1_values = self._critic_1.get_q_values(state_batch, action_batch)
        critic_2_values = self._critic_2.get_q_values(state_batch, action_batch)
        return critic_1_values, critic_2_values

Ancestors

  • torch.nn.modules.module.Module

Methods

def get_q_values(self, state_batch: torch.Tensor, action_batch: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor]

Args

state_batch : torch.Tensor
a batch of states with shape (batch_size, state_dim)
action_batch : torch.Tensor
a batch of actions with shape (batch_size, action_dim)

Returns

Tuple[torch.Tensor, torch.Tensor]
Q-values of (state, action) pairs with shape (batch_size)

Expand source code
def get_q_values(
    self,
    state_batch: torch.Tensor,
    action_batch: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Query both wrapped critics with the same batch of (state, action) pairs.

    Args:
        state_batch (torch.Tensor): a batch of states with shape (batch_size, state_dim)
        action_batch (torch.Tensor): a batch of actions with shape (batch_size, action_dim)
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the Q-values reported by the first and
        by the second critic, in that order, each with shape (batch_size)
    """
    # Critic 1 is evaluated before critic 2; the results are returned in that order.
    return (
        self._critic_1.get_q_values(state_batch, action_batch),
        self._critic_2.get_q_values(state_batch, action_batch),
    )