Module pearl.policy_learners.sequential_decision_making.actor_critic_base
Source code
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Type
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
from pearl.neural_networks.common.value_networks import QValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
ActorNetwork,
DynamicActionActorNetwork,
)
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
try:
import gymnasium as gym
except ModuleNotFoundError:
import gym
import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import (
init_weights,
update_target_network,
update_target_networks,
)
from pearl.neural_networks.common.value_networks import (
VanillaQValueNetwork,
VanillaValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.actor_networks import (
VanillaActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn, optim
class ActorCriticBase(PolicyLearner):
"""
A base class for all actor-critic based policy learners.
It provides the components shared by actor-critic methods:
- initialization of the actor and critic networks (and their target networks, when used);
- the act, reset and learn_batch methods;
- utility functions used by many actor-critic methods.
"""
def __init__(
self,
state_dim: int,
exploration_module: ExplorationModule,
actor_hidden_dims: List[int],
critic_hidden_dims: Optional[List[int]] = None,
action_space: Optional[ActionSpace] = None,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
use_actor_target: bool = False,
use_critic_target: bool = False,
actor_soft_update_tau: float = 0.005,
critic_soft_update_tau: float = 0.005,
use_twin_critic: bool = False,
discount_factor: float = 0.99,
training_rounds: int = 1,
batch_size: int = 256,
is_action_continuous: bool = False,
on_policy: bool = False,
action_representation_module: Optional[ActionRepresentationModule] = None,
) -> None:
super(ActorCriticBase, self).__init__(
on_policy=on_policy,
is_action_continuous=is_action_continuous,
training_rounds=training_rounds,
batch_size=batch_size,
exploration_module=exploration_module,
action_representation_module=action_representation_module,
action_space=action_space,
)
self._state_dim = state_dim
self._use_actor_target = use_actor_target
self._use_critic_target = use_critic_target
self._use_twin_critic = use_twin_critic
self._use_critic: bool = critic_hidden_dims is not None
self._action_dim: int = (
self.action_representation_module.representation_dim
if self.is_action_continuous
else self.action_representation_module.max_number_actions
)
# actor network takes state as input and outputs an action vector
self._actor: nn.Module = actor_network_type(
input_dim=state_dim + self._action_dim
if actor_network_type is DynamicActionActorNetwork
else state_dim,
hidden_dims=actor_hidden_dims,
output_dim=1
if actor_network_type is DynamicActionActorNetwork
else self._action_dim,
action_space=action_space,
)
self._actor.apply(init_weights)
self._actor_optimizer = optim.AdamW(
[
{
"params": self._actor.parameters(),
"lr": actor_learning_rate,
"amsgrad": True,
},
]
)
self._actor_soft_update_tau = actor_soft_update_tau
if self._use_actor_target:
self._actor_target: nn.Module = actor_network_type(
input_dim=state_dim + self._action_dim
if actor_network_type is DynamicActionActorNetwork
else state_dim,
hidden_dims=actor_hidden_dims,
output_dim=1
if actor_network_type is DynamicActionActorNetwork
else self._action_dim,
action_space=action_space,
)
update_target_network(self._actor_target, self._actor, tau=1)
self._critic_soft_update_tau = critic_soft_update_tau
if self._use_critic:
self._critic: nn.Module = make_critic(
state_dim=self._state_dim,
action_dim=self._action_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
network_type=critic_network_type,
)
self._critic_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": self._critic.parameters(),
"lr": critic_learning_rate,
"amsgrad": True,
},
]
)
if self._use_critic_target:
self._critic_target: nn.Module = make_critic(
state_dim=self._state_dim,
action_dim=self._action_dim,
hidden_dims=critic_hidden_dims,
use_twin_critic=use_twin_critic,
network_type=critic_network_type,
)
update_critic_target_network(
self._critic_target,
self._critic,
use_twin_critic,
1,
)
self._discount_factor = discount_factor
def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
if self._use_critic:
self._critic_optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value
def act(
self,
subjective_state: SubjectiveState,
available_action_space: ActionSpace,
exploit: bool = False,
) -> Action:
# Step 1: compute exploit_action
# (action computed by actor network; and without any exploration)
with torch.no_grad():
if self.is_action_continuous:
exploit_action = self._actor.sample_action(subjective_state)
action_probabilities = None
else:
assert isinstance(available_action_space, DiscreteActionSpace)
actions = self.action_representation_module(
available_action_space.actions_batch
)
action_probabilities = self._actor.get_policy_distribution(
state_batch=subjective_state,
available_actions=actions,
)
# (action_space_size)
exploit_action = torch.argmax(action_probabilities)
# Step 2: return exploit action if no exploration,
# else pass through the exploration module
if exploit:
return exploit_action
return self._exploration_module.act(
exploit_action=exploit_action,
action_space=available_action_space,
subjective_state=subjective_state,
values=action_probabilities,
)
def reset(self, action_space: ActionSpace) -> None:
self._action_space = action_space
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._critic_learn_batch(batch) # update critic
self._actor_learn_batch(batch) # update actor
if self._use_critic_target:
update_critic_target_network(
self._critic_target,
self._critic,
self._use_twin_critic,
self._critic_soft_update_tau,
)
if self._use_actor_target:
update_target_network(
self._actor_target,
self._actor,
self._actor_soft_update_tau,
)
return {}
@abstractmethod
def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
pass
@abstractmethod
def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
pass
def make_critic(
state_dim: int,
hidden_dims: Optional[Iterable[int]],
use_twin_critic: bool,
network_type: Type[QValueNetwork],
action_dim: Optional[int] = None,
) -> nn.Module:
if use_twin_critic:
assert action_dim is not None
assert hidden_dims is not None
return TwinCritic(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=hidden_dims,
network_type=network_type,
init_fn=init_weights,
)
else:
if network_type == VanillaQValueNetwork:
# pyre-ignore[45]:
# Pyre does not know that `network_type` is asserted to be concrete
return network_type(
state_dim=state_dim,
action_dim=action_dim,
hidden_dims=hidden_dims,
output_dim=1,
)
elif network_type == VanillaValueNetwork:
# pyre-ignore[45]:
# Pyre does not know that `network_type` is asserted to be concrete
return network_type(
input_dim=state_dim, hidden_dims=hidden_dims, output_dim=1
)
else:
raise NotImplementedError(
"Unknown network type. The code needs to be refactored to support this."
)
def update_critic_target_network(
target_network: nn.Module, network: nn.Module, use_twin_critic: bool, tau: float
) -> None:
if use_twin_critic:
update_target_networks(
target_network._critic_networks_combined,
network._critic_networks_combined,
tau=tau,
)
else:
update_target_network(
target_network._model,
network._model,
tau=tau,
)
def single_critic_state_value_update(
state_batch: torch.Tensor,
expected_target_batch: torch.Tensor,
optimizer: torch.optim.Optimizer,
critic: nn.Module,
) -> Dict[str, Any]:
vs = critic(state_batch)
# critic loss
criterion = torch.nn.MSELoss()
loss = criterion(
vs.reshape_as(expected_target_batch), expected_target_batch.detach()
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return {"critic_loss": loss.mean().item()}
def twin_critic_action_value_update(
state_batch: torch.Tensor,
action_batch: torch.Tensor,
expected_target_batch: torch.Tensor,
optimizer: torch.optim.Optimizer,
critic: TwinCritic,
) -> Dict[str, torch.Tensor]:
"""
Performs an optimization step on the twin critic networks.
Args:
state_batch: a batch of states with shape (batch_size, state_dim)
action_batch: a batch of actions with shape (batch_size, action_dim)
expected_target_batch: the batch of target estimates for the Bellman equation.
optimizer: the optimizer to use for the update.
critic: the critic network to update.
Returns:
Dict[str, torch.Tensor]: mean loss and individual critic losses.
"""
criterion = torch.nn.MSELoss()
optimizer.zero_grad()
q_1, q_2 = critic.get_q_values(state_batch, action_batch)
loss = criterion(
q_1.reshape_as(expected_target_batch), expected_target_batch.detach()
) + criterion(q_2.reshape_as(expected_target_batch), expected_target_batch.detach())
loss.backward()
optimizer.step()
return {
"critic_mean_loss": loss.item(),
"critic_1_values": q_1.mean().item(),
"critic_2_values": q_2.mean().item(),
}
Functions
def make_critic(state_dim: int, hidden_dims: Optional[Iterable[int]], use_twin_critic: bool, network_type: Type[QValueNetwork], action_dim: Optional[int] = None) -> torch.nn.modules.module.Module
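make_critic returns a TwinCritic when use_twin_critic is set, and otherwise a single VanillaQValueNetwork or VanillaValueNetwork depending on network_type. A minimal usage sketch follows; the dimensions are illustrative assumptions, not values taken from Pearl.

from pearl.neural_networks.common.value_networks import (
    VanillaQValueNetwork,
    VanillaValueNetwork,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import make_critic

# Twin Q-value critic: wraps two Q-networks behind one module.
twin_q_critic = make_critic(
    state_dim=8,
    action_dim=2,
    hidden_dims=[64, 64],
    use_twin_critic=True,
    network_type=VanillaQValueNetwork,
)

# Single state-value critic: action_dim is unused on this branch.
state_value_critic = make_critic(
    state_dim=8,
    action_dim=None,
    hidden_dims=[64, 64],
    use_twin_critic=False,
    network_type=VanillaValueNetwork,
)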
def single_critic_state_value_update(state_batch: torch.Tensor, expected_target_batch: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, critic: torch.nn.modules.module.Module) -> Dict[str, Any]
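single_critic_state_value_update performs one MSE regression step of a state-value critic toward a detached batch of targets. A hedged calling sketch, with made-up shapes and hyperparameters:

import torch
from torch import optim
from pearl.neural_networks.common.value_networks import VanillaValueNetwork
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    single_critic_state_value_update,
)

critic = VanillaValueNetwork(input_dim=8, hidden_dims=[64, 64], output_dim=1)
optimizer = optim.AdamW(critic.parameters(), lr=1e-3)

state_batch = torch.randn(32, 8)   # (batch_size, state_dim)
target_batch = torch.randn(32)     # e.g. discounted returns or bootstrapped values

report = single_critic_state_value_update(state_batch, target_batch, optimizer, critic)
print(report["critic_loss"])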
def twin_critic_action_value_update(state_batch: torch.Tensor, action_batch: torch.Tensor, expected_target_batch: torch.Tensor, optimizer: torch.optim.optimizer.Optimizer, critic: TwinCritic) -> Dict[str, torch.Tensor]
Performs an optimization step on the twin critic networks.
Args
state_batch
- a batch of states with shape (batch_size, state_dim)
action_batch
- a batch of actions with shape (batch_size, action_dim)
expected_target_batch
- the batch of target estimates for the Bellman equation.
optimizer
- the optimizer to use for the update.
critic
- the critic network to update.
Returns
Dict[str, torch.Tensor]
- mean loss and individual critic losses.
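A hedged calling sketch follows; the network sizes, batch shapes, and the way the Bellman targets are produced are illustrative assumptions.

import torch
from torch import optim
from pearl.neural_networks.common.utils import init_weights
from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    twin_critic_action_value_update,
)

critic = TwinCritic(
    state_dim=8,
    action_dim=2,
    hidden_dims=[64, 64],
    network_type=VanillaQValueNetwork,
    init_fn=init_weights,
)
optimizer = optim.AdamW(critic.parameters(), lr=3e-4)

state_batch = torch.randn(32, 8)
action_batch = torch.randn(32, 2)
target_batch = torch.randn(32)     # e.g. reward + discount * target-critic estimate

report = twin_critic_action_value_update(
    state_batch, action_batch, target_batch, optimizer, critic
)
print(report["critic_mean_loss"], report["critic_1_values"], report["critic_2_values"])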
def update_critic_target_network(target_network: torch.nn.modules.module.Module, network: torch.nn.modules.module.Module, use_twin_critic: bool, tau: float) -> None
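update_critic_target_network copies (tau=1) or Polyak-averages (small tau) the critic parameters into its target copy, dispatching on whether the critic is a TwinCritic. A sketch, reusing the assumed twin-critic configuration from the make_critic example above:

from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    make_critic,
    update_critic_target_network,
)

critic = make_critic(
    state_dim=8, action_dim=2, hidden_dims=[64, 64],
    use_twin_critic=True, network_type=VanillaQValueNetwork,
)
critic_target = make_critic(
    state_dim=8, action_dim=2, hidden_dims=[64, 64],
    use_twin_critic=True, network_type=VanillaQValueNetwork,
)

# Hard copy once at construction time, then soft updates after each learn_batch.
update_critic_target_network(critic_target, critic, use_twin_critic=True, tau=1)
update_critic_target_network(critic_target, critic, use_twin_critic=True, tau=0.005)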
Classes
class ActorCriticBase (state_dim: int, exploration_module: ExplorationModule, actor_hidden_dims: List[int], critic_hidden_dims: Optional[List[int]] = None, action_space: Optional[ActionSpace] = None, actor_learning_rate: float = 0.001, critic_learning_rate: float = 0.001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, use_actor_target: bool = False, use_critic_target: bool = False, actor_soft_update_tau: float = 0.005, critic_soft_update_tau: float = 0.005, use_twin_critic: bool = False, discount_factor: float = 0.99, training_rounds: int = 1, batch_size: int = 256, is_action_continuous: bool = False, on_policy: bool = False, action_representation_module: Optional[ActionRepresentationModule] = None)
A base class for all actor-critic based policy learners. It provides the components shared by actor-critic methods: initialization of the actor and critic networks (and their target networks, when used); the act, reset and learn_batch methods; and utility functions used by many actor-critic methods.
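Since _actor_learn_batch and _critic_learn_batch are abstract, ActorCriticBase cannot be used directly. The sketch below shows one way a concrete subclass might wire the two methods together. It is an illustration only: the TransitionBatch field names (state, action, reward), the placeholder reward targets, the DDPG-style actor loss, and the assumption of a continuous action space with a twin critic are all assumptions, not how Pearl's shipped learners are implemented.

from typing import Any, Dict

from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    ActorCriticBase,
    twin_critic_action_value_update,
)
from pearl.replay_buffers.transition import TransitionBatch


class MyActorCritic(ActorCriticBase):
    # Assumes is_action_continuous=True and use_twin_critic=True at construction.

    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # Placeholder target: regress the twin critic toward the observed rewards.
        return twin_critic_action_value_update(
            state_batch=batch.state,
            action_batch=batch.action,
            expected_target_batch=batch.reward,
            optimizer=self._critic_optimizer,
            critic=self._critic,
        )

    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # DDPG-style actor loss: maximize the first critic's value of the actor's action.
        action = self._actor.sample_action(batch.state)
        q_1, _ = self._critic.get_q_values(batch.state, action)
        loss = -q_1.mean()
        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()
        return {"actor_loss": loss.item()}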
Ancestors
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
- DeepDeterministicPolicyGradient
- ImplicitQLearning
- ProximalPolicyOptimization
- REINFORCE
- SoftActorCritic
- ContinuousSoftActorCritic
Methods
def act(self, subjective_state: torch.Tensor, available_action_space: ActionSpace, exploit: bool = False) -> torch.Tensor
def set_history_summarization_module(self, value: HistorySummarizationModule) -> None
Inherited members