Module pearl.policy_learners.sequential_decision_making.reinforce

Source code
from typing import Any, Dict, List, Optional, Type

from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)

from pearl.neural_networks.common.value_networks import ValueNetwork

from pearl.neural_networks.sequential_decision_making.actor_networks import ActorNetwork

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym  # noqa
import torch

from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.value_networks import VanillaValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    VanillaActorNetwork,
)
from pearl.policy_learners.exploration_modules.common.propensity_exploration import (
    PropensityExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    ActorCriticBase,
    single_critic_state_value_update,
)
from pearl.replay_buffers.transition import TransitionBatch


class REINFORCE(ActorCriticBase):
    """
    Williams, R. J. (1992). Simple statistical gradient-following algorithms
    for connectionist reinforcement learning. Machine learning, 8, 229-256.
    The critic serves as the baseline.
    """

    def __init__(
        self,
        state_dim: int,
        actor_hidden_dims: List[int],
        critic_hidden_dims: Optional[List[int]] = None,
        action_space: Optional[ActionSpace] = None,
        actor_learning_rate: float = 1e-4,
        critic_learning_rate: float = 1e-4,
        actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
        critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
        exploration_module: Optional[ExplorationModule] = None,
        discount_factor: float = 0.99,
        training_rounds: int = 1,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        super(REINFORCE, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            # pyre-fixme: the superclass expects a QValueNetwork here,
            # but this class apparently requires a ValueNetwork
            # (replacing the type and default value with QValueNetwork breaks tests)
            critic_network_type=critic_network_type,
            use_actor_target=False,
            use_critic_target=False,
            actor_soft_update_tau=0.0,  # not used
            critic_soft_update_tau=0.0,  # not used
            use_twin_critic=False,
            exploration_module=exploration_module
            if exploration_module is not None
            else PropensityExploration(),
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=0,  # REINFORCE does not use batch size
            is_action_continuous=False,
            on_policy=True,
            action_representation_module=action_representation_module,
        )

    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        # (batch_size x state_dim); note that here batch_size = episode length
        state_batch = batch.state
        return_batch = batch.cum_reward  # (batch_size)
        policy_propensities = self._actor.get_action_prob(
            batch.state,
            batch.action,
            batch.curr_available_actions,
            batch.curr_unavailable_actions_mask,
        )  # shape (batch_size)
        negative_log_probs = -torch.log(policy_propensities + 1e-8)
        if self._use_critic:
            v = self._critic(state_batch).view(-1)  # (batch_size)
            assert return_batch is not None
            loss = torch.sum(negative_log_probs * (return_batch - v.detach()))
        else:
            loss = torch.sum(negative_log_probs * return_batch)
        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()

        return {"actor_loss": loss.mean().item()}

    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        if self._use_critic:
            assert batch.cum_reward is not None
            return single_critic_state_value_update(
                state_batch=batch.state,
                expected_target_batch=batch.cum_reward,
                optimizer=self._critic_optimizer,
                critic=self._critic,
            )
        return {}
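
The following is a minimal PyTorch sketch of the objectives computed by _actor_learn_batch and _critic_learn_batch above, for readers who want to see the update in isolation. The tensors, stand-in networks, and shapes are illustrative assumptions, not part of Pearl's API; the critic objective is written here as a mean-squared-error regression toward the observed return, which is what single_critic_state_value_update is presumed to do.

import torch
import torch.nn.functional as F

# Illustrative episode: length T, 4-dimensional states, 3 discrete actions.
T, state_dim, num_actions = 8, 4, 3
states = torch.randn(T, state_dim)               # states visited in the episode
actions = torch.randint(num_actions, (T,))       # actions taken at each step
returns = torch.randn(T)                         # discounted returns G_t (batch.cum_reward)

actor = torch.nn.Linear(state_dim, num_actions)  # stand-in for VanillaActorNetwork
critic = torch.nn.Linear(state_dim, 1)           # stand-in for VanillaValueNetwork (the baseline)

# pi(a_t | s_t) for the actions actually taken, analogous to get_action_prob.
propensities = F.softmax(actor(states), dim=-1)[torch.arange(T), actions]
neg_log_probs = -torch.log(propensities + 1e-8)

# Actor loss: sum_t -log pi(a_t | s_t) * (G_t - V(s_t)).
# The baseline is detached so the policy gradient does not flow into the critic.
advantages = returns - critic(states).view(-1).detach()
actor_loss = torch.sum(neg_log_probs * advantages)

# Critic loss (assumed MSE): regress V(s_t) toward the observed return G_t.
critic_loss = F.mse_loss(critic(states).view(-1), returns)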

Classes

class REINFORCE (state_dim: int, actor_hidden_dims: List[int], critic_hidden_dims: Optional[List[int]] = None, action_space: Optional[ActionSpace] = None, actor_learning_rate: float = 0.0001, critic_learning_rate: float = 0.0001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[ValueNetwork] = pearl.neural_networks.common.value_networks.VanillaValueNetwork, exploration_module: Optional[ExplorationModule] = None, discount_factor: float = 0.99, training_rounds: int = 1, action_representation_module: Optional[ActionRepresentationModule] = None)

Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8, 229-256. The critic serves as the baseline.

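A construction sketch using the signature above. The hyperparameter values are placeholders, and the DiscreteActionSpace import path is an assumption about Pearl's utility modules (verify it against your installed version) rather than something documented on this page.

import torch

from pearl.policy_learners.sequential_decision_making.reinforce import REINFORCE
# Assumed location of Pearl's discrete action space helper; check your version.
from pearl.utils.instantiations.spaces.discrete_action_space import DiscreteActionSpace

# Hypothetical task: 4-dimensional observations and 2 discrete actions (CartPole-like).
action_space = DiscreteActionSpace(actions=[torch.tensor([a]) for a in range(2)])

policy_learner = REINFORCE(
    state_dim=4,
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],  # None presumably disables the baseline critic (see _use_critic)
    action_space=action_space,
    actor_learning_rate=1e-3,
    critic_learning_rate=1e-3,
)

In a full training loop this policy learner would typically be passed to a Pearl agent together with an on-policy replay buffer; that wiring is outside the scope of this module.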

Ancestors

ActorCriticBase

Inherited members