Module pearl.policy_learners.sequential_decision_making.actor_critic_base

Source code
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Type

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)

from pearl.neural_networks.common.value_networks import QValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    ActorNetwork,
    DynamicActionActorNetwork,
)

from pearl.utils.instantiations.spaces.box_action import BoxActionSpace

from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym

import torch

from pearl.api.action import Action

from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.history_summarization_modules.history_summarization_module import (
    HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import (
    init_weights,
    update_target_network,
    update_target_networks,
)
from pearl.neural_networks.common.value_networks import (
    VanillaQValueNetwork,
    VanillaValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    VanillaActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn, optim


class ActorCriticBase(PolicyLearner):
    """
    A base class for all actor-critic based policy learners.
    Many components are common to actor-critic methods.
        - Actor and critic (as well as target networks) network initializations.
        - Act, reset and learn_batch methods.
        - Utility functions used by many actor-critic methods.
    """

    def __init__(
        self,
        state_dim: int,
        exploration_module: ExplorationModule,
        actor_hidden_dims: List[int],
        critic_hidden_dims: Optional[List[int]] = None,
        action_space: Optional[ActionSpace] = None,
        actor_learning_rate: float = 1e-3,
        critic_learning_rate: float = 1e-3,
        actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
        critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        use_actor_target: bool = False,
        use_critic_target: bool = False,
        actor_soft_update_tau: float = 0.005,
        critic_soft_update_tau: float = 0.005,
        use_twin_critic: bool = False,
        discount_factor: float = 0.99,
        training_rounds: int = 1,
        batch_size: int = 256,
        is_action_continuous: bool = False,
        on_policy: bool = False,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        super(ActorCriticBase, self).__init__(
            on_policy=on_policy,
            is_action_continuous=is_action_continuous,
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=exploration_module,
            action_representation_module=action_representation_module,
            action_space=action_space,
        )
        self._state_dim = state_dim
        self._use_actor_target = use_actor_target
        self._use_critic_target = use_critic_target
        self._use_twin_critic = use_twin_critic
        self._use_critic: bool = critic_hidden_dims is not None

        self._action_dim: int = (
            self.action_representation_module.representation_dim
            if self.is_action_continuous
            else self.action_representation_module.max_number_actions
        )

        # The actor network maps a state to an action vector; a
        # DynamicActionActorNetwork instead maps a (state, action) pair to a single score.
        self._actor: nn.Module = actor_network_type(
            input_dim=state_dim + self._action_dim
            if actor_network_type is DynamicActionActorNetwork
            else state_dim,
            hidden_dims=actor_hidden_dims,
            output_dim=1
            if actor_network_type is DynamicActionActorNetwork
            else self._action_dim,
            action_space=action_space,
        )
        self._actor.apply(init_weights)
        self._actor_optimizer = optim.AdamW(
            [
                {
                    "params": self._actor.parameters(),
                    "lr": actor_learning_rate,
                    "amsgrad": True,
                },
            ]
        )
        self._actor_soft_update_tau = actor_soft_update_tau
        if self._use_actor_target:
            self._actor_target: nn.Module = actor_network_type(
                input_dim=state_dim + self._action_dim
                if actor_network_type is DynamicActionActorNetwork
                else state_dim,
                hidden_dims=actor_hidden_dims,
                output_dim=1
                if actor_network_type is DynamicActionActorNetwork
                else self._action_dim,
                action_space=action_space,
            )
            update_target_network(self._actor_target, self._actor, tau=1)

        self._critic_soft_update_tau = critic_soft_update_tau
        if self._use_critic:
            self._critic: nn.Module = make_critic(
                state_dim=self._state_dim,
                action_dim=self._action_dim,
                hidden_dims=critic_hidden_dims,
                use_twin_critic=use_twin_critic,
                network_type=critic_network_type,
            )
            self._critic_optimizer: optim.Optimizer = optim.AdamW(
                [
                    {
                        "params": self._critic.parameters(),
                        "lr": critic_learning_rate,
                        "amsgrad": True,
                    },
                ]
            )
            if self._use_critic_target:
                self._critic_target: nn.Module = make_critic(
                    state_dim=self._state_dim,
                    action_dim=self._action_dim,
                    hidden_dims=critic_hidden_dims,
                    use_twin_critic=use_twin_critic,
                    network_type=critic_network_type,
                )
                update_critic_target_network(
                    self._critic_target,
                    self._critic,
                    use_twin_critic,
                    1,
                )

        self._discount_factor = discount_factor

    def set_history_summarization_module(
        self, value: HistorySummarizationModule
    ) -> None:
        self._actor_optimizer.add_param_group({"params": value.parameters()})
        if self._use_critic:
            self._critic_optimizer.add_param_group({"params": value.parameters()})
        self._history_summarization_module = value

    def act(
        self,
        subjective_state: SubjectiveState,
        available_action_space: ActionSpace,
        exploit: bool = False,
    ) -> Action:
        # Step 1: compute exploit_action
        # (action computed by actor network; and without any exploration)
        with torch.no_grad():
            if self.is_action_continuous:
                exploit_action = self._actor.sample_action(subjective_state)
                action_probabilities = None
            else:
                assert isinstance(available_action_space, DiscreteActionSpace)
                actions = self.action_representation_module(
                    available_action_space.actions_batch
                )
                action_probabilities = self._actor.get_policy_distribution(
                    state_batch=subjective_state,
                    available_actions=actions,
                )
                # (action_space_size)
                exploit_action = torch.argmax(action_probabilities)

        # Step 2: return exploit action if no exploration,
        # else pass through the exploration module
        if exploit:
            return exploit_action

        return self._exploration_module.act(
            exploit_action=exploit_action,
            action_space=available_action_space,
            subjective_state=subjective_state,
            values=action_probabilities,
        )

    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        self._critic_learn_batch(batch)  # update critic
        self._actor_learn_batch(batch)  # update actor

        if self._use_critic_target:
            update_critic_target_network(
                self._critic_target,
                self._critic,
                self._use_twin_critic,
                self._critic_soft_update_tau,
            )
        if self._use_actor_target:
            update_target_network(
                self._actor_target,
                self._actor,
                self._actor_soft_update_tau,
            )
        return {}

    @abstractmethod
    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        pass

    @abstractmethod
    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        pass


def make_critic(
    state_dim: int,
    hidden_dims: Optional[Iterable[int]],
    use_twin_critic: bool,
    network_type: Type[QValueNetwork],
    action_dim: Optional[int] = None,
) -> nn.Module:
    """Builds the critic: a TwinCritic, a Q-value network, or a state-value network."""
    if use_twin_critic:
        assert action_dim is not None
        assert hidden_dims is not None
        return TwinCritic(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dims=hidden_dims,
            network_type=network_type,
            init_fn=init_weights,
        )
    else:
        if network_type == VanillaQValueNetwork:
            # pyre-ignore[45]:
            # Pyre does not know that `network_type` is asserted to be concrete
            return network_type(
                state_dim=state_dim,
                action_dim=action_dim,
                hidden_dims=hidden_dims,
                output_dim=1,
            )
        elif network_type == VanillaValueNetwork:
            # pyre-ignore[45]:
            # Pyre does not know that `network_type` is asserted to be concrete
            return network_type(
                input_dim=state_dim, hidden_dims=hidden_dims, output_dim=1
            )
        else:
            raise NotImplementedError(
                "Unknown network type. The code needs to be refactored to support this."
            )


def update_critic_target_network(
    target_network: nn.Module, network: nn.Module, use_twin_critic: bool, tau: float
) -> None:
    """Soft-updates the critic target network(s) towards the online critic(s) with rate tau."""
    if use_twin_critic:
        update_target_networks(
            target_network._critic_networks_combined,
            network._critic_networks_combined,
            tau=tau,
        )
    else:
        update_target_network(
            target_network._model,
            network._model,
            tau=tau,
        )


def single_critic_state_value_update(
    state_batch: torch.Tensor,
    expected_target_batch: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    critic: nn.Module,
) -> Dict[str, Any]:
    """
    Performs an optimization step on a state-value critic network.

    Args:
        state_batch: a batch of states with shape (batch_size, state_dim)
        expected_target_batch: a batch of target state-value estimates
        optimizer: the optimizer to use for the update.
        critic: the state-value network to update.
    Returns:
        Dict[str, Any]: the critic loss.
    """
    vs = critic(state_batch)
    # MSE loss between the critic's state-value estimates and the (detached) targets
    criterion = torch.nn.MSELoss()
    loss = criterion(
        vs.reshape_as(expected_target_batch), expected_target_batch.detach()
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"critic_loss": loss.mean().item()}


def twin_critic_action_value_update(
    state_batch: torch.Tensor,
    action_batch: torch.Tensor,
    expected_target_batch: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    critic: TwinCritic,
) -> Dict[str, torch.Tensor]:
    """
    Performs an optimization step on the twin critic networks.

    Args:
        state_batch: a batch of states with shape (batch_size, state_dim)
        action_batch: a batch of actions with shape (batch_size, action_dim)
        expected_target: the batch of target estimates for Bellman equation.
        optimizer: the optimizer to use for the update.
        critic: the critic network to update.
    Returns:
        Dict[str, torch.Tensor]: mean loss and individual critic losses.
    """

    criterion = torch.nn.MSELoss()
    optimizer.zero_grad()
    q_1, q_2 = critic.get_q_values(state_batch, action_batch)
    loss = criterion(
        q_1.reshape_as(expected_target_batch), expected_target_batch.detach()
    ) + criterion(q_2.reshape_as(expected_target_batch), expected_target_batch.detach())
    loss.backward()
    optimizer.step()

    return {
        "critic_mean_loss": loss.item(),
        "critic_1_values": q_1.mean().item(),
        "critic_2_values": q_2.mean().item(),
    }

Functions

def make_critic(state_dim: int, hidden_dims: Optional[Iterable[int]], use_twin_critic: bool, network_type: Type[QValueNetwork], action_dim: Optional[int] = None) ‑> torch.nn.modules.module.Module
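For illustration, a minimal sketch of how make_critic can be called to build either a twin critic or a single Q-value critic; the state/action dimensions and hidden layer sizes below are arbitrary placeholders:

from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.sequential_decision_making.actor_critic_base import make_critic

# Twin critic: wraps two Q-value networks of the given type
# (used when use_twin_critic=True in ActorCriticBase).
twin_critic = make_critic(
    state_dim=8,
    action_dim=2,
    hidden_dims=[64, 64],
    use_twin_critic=True,
    network_type=VanillaQValueNetwork,
)
assert isinstance(twin_critic, TwinCritic)

# Single Q-value critic: returns a VanillaQValueNetwork directly.
single_critic = make_critic(
    state_dim=8,
    action_dim=2,
    hidden_dims=[64, 64],
    use_twin_critic=False,
    network_type=VanillaQValueNetwork,
)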
def single_critic_state_value_update(state_batch: torch.Tensor, expected_target_batch: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, critic: torch.nn.modules.module.Module) ‑> Dict[str, Any]
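A hedged usage sketch, assuming a state-value critic built with make_critic and placeholder batch data:

import torch
from torch import optim

from pearl.neural_networks.common.value_networks import VanillaValueNetwork
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    make_critic,
    single_critic_state_value_update,
)

critic = make_critic(
    state_dim=8,
    action_dim=None,
    hidden_dims=[64, 64],
    use_twin_critic=False,
    network_type=VanillaValueNetwork,
)
optimizer = optim.AdamW(critic.parameters(), lr=1e-3)

state_batch = torch.randn(32, 8)   # (batch_size, state_dim)
targets = torch.randn(32)          # e.g. discounted returns computed by the caller
report = single_critic_state_value_update(
    state_batch=state_batch,
    expected_target_batch=targets,
    optimizer=optimizer,
    critic=critic,
)
print(report["critic_loss"])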
def twin_critic_action_value_update(state_batch: torch.Tensor, action_batch: torch.Tensor, expected_target_batch: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, critic: TwinCritic) ‑> Dict[str, torch.Tensor]

Performs an optimization step on the twin critic networks.

Args

state_batch: a batch of states with shape (batch_size, state_dim)
action_batch: a batch of actions with shape (batch_size, action_dim)
expected_target_batch: a batch of target estimates for the Bellman equation.
optimizer: the optimizer to use for the update.
critic: the twin critic network to update.

Returns

Dict[str, torch.Tensor]: the summed critic loss and the mean Q-value of each critic.
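A hedged usage sketch, assuming a TwinCritic built with make_critic and placeholder batch data (the Bellman targets would normally be computed by the caller from the reward and the target critic):

import torch
from torch import optim

from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    make_critic,
    twin_critic_action_value_update,
)

twin_critic = make_critic(
    state_dim=8,
    action_dim=2,
    hidden_dims=[64, 64],
    use_twin_critic=True,
    network_type=VanillaQValueNetwork,
)
optimizer = optim.AdamW(twin_critic.parameters(), lr=1e-3)

state_batch = torch.randn(32, 8)    # (batch_size, state_dim)
action_batch = torch.randn(32, 2)   # (batch_size, action_dim)
bellman_targets = torch.randn(32)   # placeholder for r + gamma * target-critic estimate
report = twin_critic_action_value_update(
    state_batch=state_batch,
    action_batch=action_batch,
    expected_target_batch=bellman_targets,
    optimizer=optimizer,
    critic=twin_critic,
)
print(report["critic_mean_loss"], report["critic_1_values"], report["critic_2_values"])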
def update_critic_target_network(target_network: torch.nn.modules.module.Module, network: torch.nn.modules.module.Module, use_twin_critic: bool, tau: float) ‑> None
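This helper dispatches to update_target_networks for a TwinCritic and to update_target_network otherwise. The rule they apply is the usual Polyak (soft) update; the sketch below only illustrates the idea and is not the Pearl implementation:

import torch
from torch import nn

@torch.no_grad()
def soft_update_sketch(target: nn.Module, online: nn.Module, tau: float) -> None:
    # target_param <- tau * online_param + (1 - tau) * target_param
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.mul_(1.0 - tau).add_(o_param, alpha=tau)

With tau=1, as used in __init__, this reduces to a hard copy of the online network into the target.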

Classes

class ActorCriticBase (state_dim: int, exploration_module: ExplorationModule, actor_hidden_dims: List[int], critic_hidden_dims: Optional[List[int]] = None, action_space: Optional[ActionSpace] = None, actor_learning_rate: float = 0.001, critic_learning_rate: float = 0.001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, use_actor_target: bool = False, use_critic_target: bool = False, actor_soft_update_tau: float = 0.005, critic_soft_update_tau: float = 0.005, use_twin_critic: bool = False, discount_factor: float = 0.99, training_rounds: int = 1, batch_size: int = 256, is_action_continuous: bool = False, on_policy: bool = False, action_representation_module: Optional[ActionRepresentationModule] = None)

A base class for all actor-critic policy learners. It implements the components common to actor-critic methods:

- initialization of the actor and critic networks (and of their target networks, if used).
- the act, reset and learn_batch methods.
- utility functions shared by many actor-critic methods.
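To make the abstract contract concrete, here is a hedged sketch of a minimal subclass; the hook bodies are placeholders, not the loss computations of any actual Pearl learner:

from typing import Any, Dict

from pearl.policy_learners.sequential_decision_making.actor_critic_base import ActorCriticBase
from pearl.replay_buffers.transition import TransitionBatch


class MyActorCritic(ActorCriticBase):
    """Hypothetical learner: only the two abstract hooks must be implemented."""

    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # Compute an actor loss from the batch (e.g. a policy-gradient or
        # deterministic policy-gradient objective), then step self._actor_optimizer.
        ...
        return {}

    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # Compute Bellman targets from the batch and regress the critic towards
        # them, e.g. via twin_critic_action_value_update or
        # single_critic_state_value_update, using self._critic_optimizer.
        ...
        return {}

learn_batch calls these two hooks and then performs the soft updates of the target networks, so subclasses normally do not need to override it.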



Ancestors

pearl.policy_learners.policy_learner.PolicyLearner
torch.nn.modules.module.Module

Subclasses

Methods

def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) ‑> torch.Tensor
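A short usage sketch; `learner` (a concrete subclass instance) and `action_space` (a DiscreteActionSpace) are placeholders assumed to exist, and the state dimension of 8 is arbitrary:

import torch

subjective_state = torch.randn(8)  # a (possibly history-summarized) state

# Greedy action chosen by the actor network, bypassing exploration:
greedy_action = learner.act(subjective_state, action_space, exploit=True)

# Action passed through the learner's exploration module:
explored_action = learner.act(subjective_state, action_space, exploit=False)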
def set_history_summarization_module(self, value: HistorySummarizationModule) ‑> None
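This setter registers the summarizer's parameters with the actor optimizer (and with the critic optimizer, when a critic is used), so the history summarizer is trained jointly with the actor and critic. A minimal sketch, with `learner` and `summarizer` as placeholders:

# `summarizer` is any HistorySummarizationModule instance.
learner.set_history_summarization_module(summarizer)
# Its parameters now appear in learner._actor_optimizer (and, if a critic is
# configured, in learner._critic_optimizer) and are updated whenever those
# optimizers step.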

Inherited members