Module pearl.policy_learners.sequential_decision_making.soft_actor_critic
from typing import Any, Dict, List, Optional, Type

import torch
import torch.nn.functional as F

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    ActorNetwork,
    VanillaActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.q_value_network import (
    QValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.common.propensity_exploration import (
    PropensityExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    ActorCriticBase,
    twin_critic_action_value_update,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import optim
# Currently available actions is not used. Needs to be updated once we know the input
# structure of production stack on this param.
# TODO: to make things easier with a single optimizer, we need to polish this method.
class SoftActorCritic(ActorCriticBase):
    """
    Implementation of Soft Actor Critic Policy Learner for discrete action spaces.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        actor_hidden_dims: List[int],
        critic_hidden_dims: List[int],
        actor_learning_rate: float = 1e-4,
        critic_learning_rate: float = 1e-4,
        actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
        critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        critic_soft_update_tau: float = 0.005,
        exploration_module: Optional[ExplorationModule] = None,
        discount_factor: float = 0.99,
        training_rounds: int = 100,
        batch_size: int = 128,
        entropy_coef: float = 0.2,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        super(SoftActorCritic, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            critic_network_type=critic_network_type,
            use_actor_target=False,
            use_critic_target=True,
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=critic_soft_update_tau,
            use_twin_critic=True,
            exploration_module=exploration_module
            if exploration_module is not None
            else PropensityExploration(),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            is_action_continuous=False,
            on_policy=False,
            action_representation_module=action_representation_module,
        )

        # This is needed to avoid actor softmax overflow issue.
        # Should not be left for users to choose.
        self.scheduler = optim.lr_scheduler.ExponentialLR(
            self._actor_optimizer, gamma=0.99
        )

        # TODO: implement learnable entropy coefficient
        self._entropy_coef = entropy_coef

    # sac uses a learning rate scheduler specifically
    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space
        self.scheduler.step()
    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        reward_batch = batch.reward  # (batch_size)
        done_batch = batch.done  # (batch_size)

        assert done_batch is not None
        expected_state_action_values = (
            self._get_next_state_expected_values(batch)
            * self._discount_factor
            * (1 - done_batch.float())
        ) + reward_batch  # (batch_size), r + gamma * V(s)

        assert isinstance(self._critic, TwinCritic)
        loss_critic_update = twin_critic_action_value_update(
            state_batch=batch.state,
            action_batch=batch.action,
            expected_target_batch=expected_state_action_values,
            optimizer=self._critic_optimizer,
            critic=self._critic,
        )
        return loss_critic_update
    @torch.no_grad()
    def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:
        next_state_batch = batch.next_state  # (batch_size x state_dim)
        next_available_actions_batch = (
            batch.next_available_actions
        )  # (batch_size x action_space_size x action_dim)
        next_unavailable_actions_mask_batch = (
            batch.next_unavailable_actions_mask
        )  # (batch_size x action_space_size)
        assert next_state_batch is not None
        assert next_available_actions_batch is not None

        next_state_batch_repeated = torch.repeat_interleave(
            next_state_batch.unsqueeze(1),
            self.action_representation_module.max_number_actions,
            dim=1,
        )  # (batch_size x action_space_size x state_dim)

        # get q values of (states, all actions) from twin critics
        next_q1, next_q2 = self._critic_target.get_q_values(
            state_batch=next_state_batch_repeated,
            action_batch=next_available_actions_batch,
        )

        # clipped double q-learning (reduce overestimation bias)
        next_q = torch.minimum(next_q1, next_q2)

        # random ensemble distillation (reduce overestimation bias)
        # random_index = torch.randint(0, 2, (1,)).item()
        # next_q = next_q1 if random_index == 0 else next_q2

        next_state_action_values = next_q.view(
            self.batch_size, -1
        )  # (batch_size x action_space_size)

        # Make sure that unavailable actions' Q values are assigned to 0.0
        # since we are calculating expectation
        if next_unavailable_actions_mask_batch is not None:
            next_state_action_values[next_unavailable_actions_mask_batch] = 0.0

        next_state_policy_dist = self._actor.get_policy_distribution(
            state_batch=next_state_batch,
            available_actions=next_available_actions_batch,
            unavailable_actions_mask=next_unavailable_actions_mask_batch,
        )  # (batch_size x action_space_size)

        # Entropy Regularization
        next_state_action_values = (
            next_state_action_values
            - self._entropy_coef * torch.log(next_state_policy_dist + 1e-8)
        ) * next_state_policy_dist  # (batch_size x action_space_size)

        return next_state_action_values.sum(dim=1)
    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        state_batch = batch.state  # (batch_size x state_dim)

        state_batch_repeated = torch.repeat_interleave(
            state_batch.unsqueeze(1),
            self.action_representation_module.max_number_actions,
            dim=1,
        )  # (batch_size x action_space_size x state_dim)

        available_actions = (
            batch.curr_available_actions
        )  # (batch_size x action_space_size x action_dim)

        # get q values of (states, all actions) from twin critics
        q1, q2 = self._critic.get_q_values(
            state_batch=state_batch_repeated, action_batch=available_actions
        )
        # clipped double q learning (reduce overestimation bias)
        q = torch.minimum(q1, q2)

        unavailable_actions_mask = (
            batch.curr_unavailable_actions_mask
        )  # (batch_size x action_space_size)

        new_policy_dist = self._actor.get_policy_distribution(
            state_batch=state_batch,
            available_actions=available_actions,
            unavailable_actions_mask=unavailable_actions_mask,
        )  # (batch_size x action_space_size)

        state_action_values = q.view(
            (self.batch_size, self.action_representation_module.max_number_actions)
        )  # (batch_size x action_space_size)

        if unavailable_actions_mask is not None:
            state_action_values[unavailable_actions_mask] = 0.0

        policy_loss = (
            (
                new_policy_dist
                * (
                    self._entropy_coef * torch.log(new_policy_dist + 1e-8)
                    - state_action_values
                )
            )
            .sum(dim=1)
            .mean()
        )

        self._actor_optimizer.zero_grad()
        policy_loss.backward()
        self._actor_optimizer.step()

        return {"actor_loss": policy_loss.item()}
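For reference, _get_next_state_expected_values and _actor_learn_batch above implement the discrete-action form of the SAC objectives: the soft value of the next state V(s') = sum_a pi(a|s') * (min(Q1, Q2)(s', a) - alpha * log pi(a|s')), the critic regression target r + gamma * (1 - done) * V(s'), and the actor loss E_s[sum_a pi(a|s) * (alpha * log pi(a|s) - min(Q1, Q2)(s, a))]. The standalone sketch below reproduces the same arithmetic on toy tensors; the tensor names, shapes, and random values are purely illustrative and are not part of the Pearl API.

import torch

# Toy shapes for illustration only.
batch_size, num_actions = 4, 3
entropy_coef, discount = 0.2, 0.99

# Stand-ins for the twin critic outputs on (next_state, a) for every action a,
# and for the actor's softmax distribution over actions at the next state.
next_q1 = torch.randn(batch_size, num_actions)
next_q2 = torch.randn(batch_size, num_actions)
next_pi = torch.softmax(torch.randn(batch_size, num_actions), dim=1)

reward = torch.randn(batch_size)
done = torch.zeros(batch_size)

# Clipped double Q-learning: element-wise minimum of the twin critics.
next_q = torch.minimum(next_q1, next_q2)

# Soft state value: V(s') = sum_a pi(a|s') * (min Q(s', a) - alpha * log pi(a|s'))
next_v = (next_pi * (next_q - entropy_coef * torch.log(next_pi + 1e-8))).sum(dim=1)

# Critic regression target: y = r + gamma * (1 - done) * V(s')
critic_target = reward + discount * (1.0 - done) * next_v

# Actor objective (shown on the same toy tensors for brevity):
# J(pi) = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min Q(s, a)) ]
actor_loss = (
    (next_pi * (entropy_coef * torch.log(next_pi + 1e-8) - next_q)).sum(dim=1).mean()
)

print(critic_target.shape, actor_loss.item())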
Classes
class SoftActorCritic (state_dim: int, action_space: ActionSpace, actor_hidden_dims: List[int], critic_hidden_dims: List[int], actor_learning_rate: float = 0.0001, critic_learning_rate: float = 0.0001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, critic_soft_update_tau: float = 0.005, exploration_module: Optional[ExplorationModule] = None, discount_factor: float = 0.99, training_rounds: int = 100, batch_size: int = 128, entropy_coef: float = 0.2, action_representation_module: Optional[ActionRepresentationModule] = None)
Implementation of Soft Actor Critic Policy Learner for discrete action spaces.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
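A minimal construction sketch based on the signature above. The DiscreteActionSpace and OneHotActionTensorRepresentationModule import paths are assumptions drawn from Pearl's tutorials and may differ across versions; every other argument comes from the constructor documented here.

import torch

from pearl.action_representation_modules.one_hot_action_representation_module import (
    OneHotActionTensorRepresentationModule,  # assumed path, check your Pearl version
)
from pearl.policy_learners.sequential_decision_making.soft_actor_critic import (
    SoftActorCritic,
)
from pearl.utils.instantiations.spaces.discrete_action import (
    DiscreteActionSpace,  # assumed path, check your Pearl version
)

num_actions = 4
action_space = DiscreteActionSpace(
    actions=[torch.tensor([i]) for i in range(num_actions)]
)

policy_learner = SoftActorCritic(
    state_dim=8,
    action_space=action_space,
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],
    entropy_coef=0.2,  # fixed temperature; a learnable coefficient is a TODO in the source
    batch_size=128,
    training_rounds=100,
    action_representation_module=OneHotActionTensorRepresentationModule(
        max_number_actions=num_actions
    ),
)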
Ancestors
- ActorCriticBase
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Inherited members