Module pearl.policy_learners.sequential_decision_making.ppo
Expand source code
import copy
from typing import Any, Dict, List, Optional, Type

import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.value_networks import (
    ValueNetwork,
    VanillaValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    ActorNetwork,
    VanillaActorNetwork,
)
from pearl.policy_learners.exploration_modules.common.propensity_exploration import (
    PropensityExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    ActorCriticBase,
    single_critic_state_value_update,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn
class ProximalPolicyOptimization(ActorCriticBase):
    """
    Proximal Policy Optimization (PPO).
    paper: https://arxiv.org/pdf/1707.06347.pdf
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        actor_hidden_dims: List[int],
        critic_hidden_dims: Optional[List[int]],
        actor_learning_rate: float = 1e-4,
        critic_learning_rate: float = 1e-4,
        exploration_module: Optional[ExplorationModule] = None,
        actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
        critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
        discount_factor: float = 0.99,
        training_rounds: int = 100,
        batch_size: int = 128,
        epsilon: float = 0.0,
        entropy_bonus_scaling: float = 0.01,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        super(ProximalPolicyOptimization, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            # pyre-fixme: the superclass expects a QValueNetwork here,
            # but this class requires a ValueNetwork
            # (replacing the type and default value with QValueNetwork breaks tests)
            critic_network_type=critic_network_type,
            use_actor_target=False,
            use_critic_target=False,
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=0.0,  # not used
            use_twin_critic=False,
            exploration_module=exploration_module
            if exploration_module is not None
            else PropensityExploration(),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            is_action_continuous=False,
            on_policy=True,
            action_representation_module=action_representation_module,
        )
        self._epsilon = epsilon
        self._entropy_bonus_scaling = entropy_bonus_scaling
        # frozen copy of the actor, used to compute old-policy probabilities
        # for the clipped importance-sampling ratio
        self._actor_old: nn.Module = copy.deepcopy(self._actor)
    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        """
        Actor loss = clipped surrogate loss - entropy_bonus_scaling * entropy bonus.
        (The critic is updated separately in _critic_learn_batch.)
        """
        # TODO: change the output shape of value networks
        vs: torch.Tensor = self._critic(batch.state).view(-1)  # shape (batch_size)
        action_probs = self._actor.get_action_prob(
            state_batch=batch.state,
            action_batch=batch.action,
            available_actions=batch.curr_available_actions,
            unavailable_actions_mask=batch.curr_unavailable_actions_mask,
        )  # shape (batch_size)

        # actor loss
        with torch.no_grad():
            action_probs_old = self._actor_old.get_action_prob(
                state_batch=batch.state,
                action_batch=batch.action,
                available_actions=batch.curr_available_actions,
                unavailable_actions_mask=batch.curr_unavailable_actions_mask,
            )  # shape (batch_size)
        # importance-sampling ratio r(theta) = pi_theta(a|s) / pi_theta_old(a|s)
        r_theta = torch.div(action_probs, action_probs_old)  # shape (batch_size)
        clip = torch.clamp(
            r_theta, min=1.0 - self._epsilon, max=1.0 + self._epsilon
        )  # shape (batch_size)

        # advantage estimator; in the paper:
        # A = sum(lambda^t * gamma^t * TD_error), where TD_error = reward + gamma * V(s') - V(s)
        # with lambda = 1 and gamma = 1 this reduces to
        # A = sum(TD_error) = return - V(s)
        # TODO: support lambda and gamma
        with torch.no_grad():
            advantage = batch.cum_reward - vs  # shape (batch_size)

        # entropy
        # Categorical is suitable for environments with discrete actions (e.g. CartPole)
        # TODO: support continuous actions
        entropy: torch.Tensor = torch.distributions.Categorical(
            action_probs.detach()
        ).entropy()
        loss = torch.sum(
            -torch.min(r_theta * advantage, clip * advantage)
        ) - torch.sum(self._entropy_bonus_scaling * entropy)
        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()
        return {"actor_loss": loss.mean().item()}
    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        assert batch.cum_reward is not None
        return single_critic_state_value_update(
            state_batch=batch.state,
            expected_target_batch=batch.cum_reward,
            optimizer=self._critic_optimizer,
            critic=self._critic,
        )

    def learn(self, replay_buffer: ReplayBuffer) -> Dict[str, Any]:
        result = super().learn(replay_buffer)
        # update old actor with latest actor for next round
        self._actor_old.load_state_dict(self._actor.state_dict())
        return result
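
For reference, a minimal construction sketch (illustrative, not part of the module): it assumes a discrete Pearl action space object named `action_space` and an on-policy `ReplayBuffer` named `replay_buffer` that has already been filled by the acting loop; both names, as well as the hyperparameter values, are placeholders for this example.

from pearl.policy_learners.sequential_decision_making.ppo import (
    ProximalPolicyOptimization,
)

# `action_space` (discrete) and `replay_buffer` (on-policy) are assumed to exist.
ppo = ProximalPolicyOptimization(
    state_dim=4,                  # e.g. CartPole observation size (placeholder)
    action_space=action_space,    # assumed discrete Pearl action space
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],
    epsilon=0.1,                  # clipping range of the surrogate objective
    entropy_bonus_scaling=0.01,
    training_rounds=50,
    batch_size=64,
)

# After collecting on-policy transitions into `replay_buffer`:
report = ppo.learn(replay_buffer)  # also syncs the old actor to the latest actor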
Classes
class ProximalPolicyOptimization (state_dim: int, action_space: ActionSpace, actor_hidden_dims: List[int], critic_hidden_dims: Optional[List[int]], actor_learning_rate: float = 0.0001, critic_learning_rate: float = 0.0001, exploration_module: Optional[ExplorationModule] = None, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[ValueNetwork] = pearl.neural_networks.common.value_networks.VanillaValueNetwork, discount_factor: float = 0.99, training_rounds: int = 100, batch_size: int = 128, epsilon: float = 0.0, entropy_bonus_scaling: float = 0.01, action_representation_module: Optional[ActionRepresentationModule] = None)
Proximal Policy Optimization (PPO). paper: https://arxiv.org/pdf/1707.06347.pdf
Initializes internal Module state, shared by both nn.Module and ScriptModule.
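
For orientation, the actor update in `_actor_learn_batch` minimizes the negative of the clipped surrogate objective from the PPO paper plus an entropy bonus. A sketch in the paper's notation, matching the `epsilon` and `entropy_bonus_scaling` parameters above:

L(\theta) = -\sum_t \min\!\big(r_t(\theta)\,\hat{A}_t,\;
    \mathrm{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t\big)
    \;-\; c \sum_t H\!\big[\pi_\theta(\cdot \mid s_t)\big],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

where c is `entropy_bonus_scaling`, epsilon is the clipping range `epsilon`, and the advantage estimate used by this class is the observed return minus the critic's value, i.e. `batch.cum_reward - V(s)`.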
Ancestors
- ActorCriticBase
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Inherited members