Module pearl.policy_learners.sequential_decision_making.td3

from typing import Any, Dict, List, Optional, Type

import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.utils import update_target_network
from pearl.neural_networks.common.value_networks import VanillaQValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    ActorNetwork,
    VanillaContinuousActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.q_value_network import (
    QValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
    make_critic,
    twin_critic_action_value_update,
    update_critic_target_network,
)
from pearl.policy_learners.sequential_decision_making.ddpg import (
    DeepDeterministicPolicyGradient,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace
from torch import nn, optim


class TD3(DeepDeterministicPolicyGradient):
    """
    TD3 uses a deterministic actor, twin critics, and a delayed actor update.
        - Since the actor is deterministic, an exploration module is required
          for data collection.
        - To disable exploration, use the NoExploration module.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        actor_hidden_dims: List[int],
        critic_hidden_dims: List[int],
        exploration_module: Optional[ExplorationModule] = None,
        actor_learning_rate: float = 1e-3,
        critic_learning_rate: float = 1e-3,
        actor_network_type: Type[ActorNetwork] = VanillaContinuousActorNetwork,
        critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        actor_soft_update_tau: float = 0.005,
        critic_soft_update_tau: float = 0.005,
        discount_factor: float = 0.99,
        training_rounds: int = 1,
        batch_size: int = 256,
        actor_update_freq: int = 2,
        actor_update_noise: float = 0.2,
        actor_update_noise_clip: float = 0.5,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        assert isinstance(action_space, BoxActionSpace)
        super(TD3, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            exploration_module=exploration_module,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            critic_network_type=critic_network_type,
            actor_soft_update_tau=actor_soft_update_tau,
            critic_soft_update_tau=critic_soft_update_tau,
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            action_representation_module=action_representation_module,
        )
        self._action_space: BoxActionSpace = action_space
        self._actor_update_freq = actor_update_freq
        self._actor_update_noise = actor_update_noise
        self._actor_update_noise_clip = actor_update_noise_clip
        self._critic_update_count = 0

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

        self._critic_learn_batch(batch)  # critic update
        self._critic_update_count += 1

        # delayed actor update
        if self._critic_update_count % self._actor_update_freq == 0:
            # see ddpg base class for actor update details
            self._actor_learn_batch(batch)

            # update targets of critics using soft updates
            update_critic_target_network(
                self._critic_target,
                self._critic,
                self._use_twin_critic,
                self._critic_soft_update_tau,
            )
            # update target of actor network using soft updates
            update_target_network(
                self._actor_target, self._actor, self._actor_soft_update_tau
            )

        return {}

    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

        with torch.no_grad():
            # sample next_action from actor's target network; shape (batch_size, action_dim)
            next_action = self._actor_target.sample_action(batch.next_state)

            # sample Gaussian noise (clipped below) for target policy smoothing
            noise = torch.normal(
                mean=0,
                std=self._actor_update_noise,
                size=next_action.size(),
                device=batch.device,
            )

            noise = torch.clamp(
                noise,
                -self._actor_update_noise_clip,
                self._actor_update_noise_clip,
            )  # shape (batch_size, action_dim)

            # add clipped noise to next_action
            low = torch.tensor(self._action_space.low, device=batch.device)
            high = torch.tensor(self._action_space.high, device=batch.device)

            next_action = torch.clamp(
                next_action + noise, low, high
            )  # shape (batch_size, action_dim)

            # sample q values of (next_state, next_action) from targets of critics
            next_q1, next_q2 = self._critic_target.get_q_values(
                state_batch=batch.next_state,
                action_batch=next_action,
            )  # shape (batch_size)

            # clipped double q learning (reduce overestimation bias)
            next_q = torch.minimum(next_q1, next_q2)

            # compute bellman target:
            # r + gamma * (min{Qtarget_1(s', a from target actor network),
            #                  Qtarget_2(s', a from target actor network)})
            expected_state_action_values = (
                next_q * self._discount_factor * (1 - batch.done.float())
            ) + batch.reward  # (batch_size)

        # update twin critics towards bellman target
        assert isinstance(self._critic, TwinCritic)
        loss_critic_update = twin_critic_action_value_update(
            state_batch=batch.state,
            action_batch=batch.action,
            expected_target_batch=expected_state_action_values,
            optimizer=self._critic_optimizer,
            critic=self._critic,
        )
        return loss_critic_update


class RCTD3(TD3):
    """
    RCTD3 extends TD3 for reward-constrained policy optimization: a separate
    cost critic is learned, and the actor maximizes
    Q(s, a) - lambda * Q_cost(s, a).
        - Since the actor is deterministic, an exploration module is required
          for data collection.
        - To disable exploration, use the NoExploration module.
    """

    def __init__(
        self,
        state_dim: int,
        action_space: ActionSpace,
        actor_hidden_dims: List[int],
        critic_hidden_dims: List[int],
        exploration_module: Optional[ExplorationModule] = None,
        actor_learning_rate: float = 1e-3,
        critic_learning_rate: float = 1e-3,
        actor_network_type: Type[ActorNetwork] = VanillaContinuousActorNetwork,
        critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
        actor_soft_update_tau: float = 0.005,
        critic_soft_update_tau: float = 0.005,
        discount_factor: float = 0.99,
        training_rounds: int = 1,
        batch_size: int = 256,
        actor_update_freq: int = 2,
        actor_update_noise: float = 0.2,
        actor_update_noise_clip: float = 0.5,
        lambda_constraint: float = 1.0,
        cost_discount_factor: float = 0.5,
    ) -> None:
        super(RCTD3, self).__init__(
            state_dim=state_dim,
            action_space=action_space,
            actor_hidden_dims=actor_hidden_dims,
            critic_hidden_dims=critic_hidden_dims,
            exploration_module=exploration_module,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_network_type=actor_network_type,
            critic_network_type=critic_network_type,
            actor_soft_update_tau=actor_soft_update_tau,
            critic_soft_update_tau=critic_soft_update_tau,
            discount_factor=discount_factor,
            training_rounds=training_rounds,
            batch_size=batch_size,
            actor_update_freq=actor_update_freq,
            actor_update_noise=actor_update_noise,
            actor_update_noise_clip=actor_update_noise_clip,
        )
        self.lambda_constraint = lambda_constraint
        self.cost_discount_factor = cost_discount_factor

        # initialize cost critic
        self.cost_critic: nn.Module = make_critic(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            hidden_dims=critic_hidden_dims,
            use_twin_critic=self._use_twin_critic,
            network_type=critic_network_type,
        )
        self._cost_critic_optimizer = optim.AdamW(
            [
                {
                    "params": self.cost_critic.parameters(),
                    "lr": critic_learning_rate,
                    "amsgrad": True,
                },
            ]
        )
        self.target_of_cost_critic: nn.Module = make_critic(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            hidden_dims=critic_hidden_dims,
            use_twin_critic=self._use_twin_critic,
            network_type=critic_network_type,
        )
        update_critic_target_network(
            self.target_of_cost_critic,
            self.cost_critic,
            self._use_twin_critic,
            1,
        )

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

        # update critics
        self._critic_learn_batch(batch)
        self._critic_update_count += 1

        # update lambda to the current value of safety module
        self.lambda_constraint = self.safety_module.lambda_constraint

        # delayed actor update
        if self._critic_update_count % self._actor_update_freq == 0:
            # see ddpg base class for actor update details
            self._actor_learn_batch(batch)

            # update targets of twin critics using soft updates
            update_critic_target_network(
                self._critic_target,
                self._critic,
                self._use_twin_critic,
                self._critic_soft_update_tau,
            )

            # update targets of cost twin critics using soft updates
            update_critic_target_network(
                self.target_of_cost_critic,
                self.cost_critic,
                self._use_twin_critic,
                self._critic_soft_update_tau,
            )

            # update target of actor network using soft updates
            update_target_network(
                self._actor_target, self._actor, self._actor_soft_update_tau
            )

        return {}

    def _actor_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

        # sample a batch of actions from the actor network; shape (batch_size, action_dim)
        action_batch = self._actor.sample_action(batch.state)

        # sample q values for (batch.state, action_batch) from the reward twin critics
        q1, q2 = self._critic.get_q_values(
            state_batch=batch.state, action_batch=action_batch
        )
        # clipped double q learning (reduce overestimation bias); shape (batch_size)
        q = torch.minimum(q1, q2)

        # sample cost q values for (batch.state, action_batch) from the cost twin critics
        cost_q1, cost_q2 = self.cost_critic.get_q_values(
            state_batch=batch.state, action_batch=action_batch
        )
        # take the larger of the two cost estimates (conservative w.r.t. cost); shape (batch_size)
        cost_q = torch.maximum(cost_q1, cost_q2)
        # optimization objective: maximize Q(s, a) - lambda * Q_cost(s, a),
        # i.e. minimize its negation
        loss = -(q.mean() - self.lambda_constraint * cost_q.mean())

        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()

        return {"actor_loss": loss.mean().item()}

    def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        res = {}
        train_critic_res = self._critic_custom_learn_batch(
            batch,
            critic=self._critic,
            target_of_critic=self._critic_target,
            critic_optimizer=self._critic_optimizer,
            discount_factor=self._discount_factor,
            critic_key="reward",
        )
        # TODO: take max instead of min in the cost critic?
        train_cost_critic_res = self._critic_custom_learn_batch(
            batch,
            critic=self.cost_critic,
            target_of_critic=self.target_of_cost_critic,
            critic_optimizer=self._cost_critic_optimizer,
            discount_factor=self.cost_discount_factor,
            critic_key="cost",
        )
        # modify the keys name
        train_cost_critic_res_new = {}
        for key, value in train_cost_critic_res.items():
            train_cost_critic_res_new["cost_{}".format(key)] = value
        res.update(train_critic_res)
        res.update(train_cost_critic_res_new)
        return res

    def _critic_custom_learn_batch(
        self,
        batch: TransitionBatch,
        critic: nn.Module,
        target_of_critic: nn.Module,
        critic_optimizer: optim.Optimizer,
        discount_factor: float,
        critic_key: str = "reward",
    ) -> Dict[str, Any]:

        assert critic_key in ["reward", "cost"]
        with torch.no_grad():
            # sample next_action from actor's target network; shape (batch_size, action_dim)
            next_action = self._actor_target.sample_action(batch.next_state)

            # sample Gaussian noise (clipped below) for target policy smoothing
            noise = torch.normal(
                mean=0,
                std=self._actor_update_noise,
                size=next_action.size(),
                device=batch.device,
            )

            noise = torch.clamp(
                noise,
                -self._actor_update_noise_clip,
                self._actor_update_noise_clip,
            )  # shape (batch_size, action_dim)

            # add clipped noise to next_action
            low = torch.tensor(self._action_space.low, device=batch.device)
            high = torch.tensor(self._action_space.high, device=batch.device)

            next_action = torch.clamp(
                next_action + noise, low, high
            )  # shape (batch_size, action_dim)

            # sample q values of (next_state, next_action) from targets of twin critics
            next_q1, next_q2 = target_of_critic.get_q_values(
                state_batch=batch.next_state,
                action_batch=next_action,
            )  # shape (batch_size)

            # clipped double q learning (reduce overestimation bias)
            next_q = torch.minimum(next_q1, next_q2)

            # compute bellman target:
            # r (or cost) + gamma * min{Qtarget_1(s', a'), Qtarget_2(s', a')},
            # where a' is drawn from the target actor network
            reward_or_cost = batch.reward if critic_key == "reward" else batch.cost

            expected_state_action_values = (
                next_q * discount_factor * (1 - batch.done.float())
            ) + reward_or_cost  # (batch_size)

        # update critics towards bellman target

        loss_critic_update = twin_critic_action_value_update(
            state_batch=batch.state,
            action_batch=batch.action,
            expected_target_batch=expected_state_action_values,
            optimizer=critic_optimizer,
            # pyre-fixme
            critic=critic,
        )
        return loss_critic_update

Classes

class RCTD3 (state_dim: int, action_space: ActionSpace, actor_hidden_dims: List[int], critic_hidden_dims: List[int], exploration_module: Optional[ExplorationModule] = None, actor_learning_rate: float = 0.001, critic_learning_rate: float = 0.001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaContinuousActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, actor_soft_update_tau: float = 0.005, critic_soft_update_tau: float = 0.005, discount_factor: float = 0.99, training_rounds: int = 1, batch_size: int = 256, actor_update_freq: int = 2, actor_update_noise: float = 0.2, actor_update_noise_clip: float = 0.5, lambda_constraint: float = 1.0, cost_discount_factor: float = 0.5)

RCTD3 extends TD3 for reward-constrained policy optimization: a separate cost critic is learned, and the actor maximizes Q(s, a) - lambda * Q_cost(s, a). Since the actor is deterministic, an exploration module is required for data collection; to disable exploration, use the NoExploration module.

In addition to the TD3 setup, the constructor builds a cost critic, its target network, and a dedicated AdamW optimizer for the cost critic.
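
A minimal construction sketch follows; the action bounds, state dimension, and hidden sizes below are illustrative assumptions, not values from this module. Note that learn_batch reads lambda_constraint from the learner's safety module and the cost Bellman target uses batch.cost, so RCTD3 must be paired with a safety module that exposes lambda_constraint and with transitions that carry a cost signal.

import torch

from pearl.policy_learners.sequential_decision_making.td3 import RCTD3
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace

# Assumption: BoxActionSpace is built from low/high bound tensors;
# this sketches a 1-dimensional action in [-1, 1].
action_space = BoxActionSpace(low=torch.tensor([-1.0]), high=torch.tensor([1.0]))

rctd3 = RCTD3(
    state_dim=4,
    action_space=action_space,
    actor_hidden_dims=[256, 256],
    critic_hidden_dims=[256, 256],
    lambda_constraint=1.0,     # initial weight of the cost critic in the actor loss
    cost_discount_factor=0.5,  # discount used for the cost Bellman target
)

# Actor objective: maximize Q(s, a) - lambda * Q_cost(s, a).
# rctd3.learn_batch(batch)  # requires batch.cost and a safety module exposing lambda_constraint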


class TD3 (state_dim: int, action_space: ActionSpace, actor_hidden_dims: List[int], critic_hidden_dims: List[int], exploration_module: Optional[ExplorationModule] = None, actor_learning_rate: float = 0.001, critic_learning_rate: float = 0.001, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaContinuousActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, actor_soft_update_tau: float = 0.005, critic_soft_update_tau: float = 0.005, discount_factor: float = 0.99, training_rounds: int = 1, batch_size: int = 256, actor_update_freq: int = 2, actor_update_noise: float = 0.2, actor_update_noise_clip: float = 0.5, action_representation_module: Optional[ActionRepresentationModule] = None)

TD3 uses a deterministic actor, twin critics, and a delayed actor update. Since the actor is deterministic, an exploration module is required for data collection; to disable exploration, use the NoExploration module.

In addition to the DDPG setup, the constructor stores the actor update frequency and the target policy smoothing noise parameters.
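
Below is a minimal construction sketch. The BoxActionSpace bounds, state dimension, and hidden sizes are illustrative assumptions, not values from this module; in practice an exploration module (e.g. Gaussian action noise) is supplied, and the TransitionBatch passed to learn_batch is sampled from a replay buffer.

import torch

from pearl.policy_learners.sequential_decision_making.td3 import TD3
from pearl.utils.instantiations.spaces.box_action import BoxActionSpace

# Assumption: BoxActionSpace is built from low/high bound tensors;
# this sketches a 1-dimensional action in [-1, 1].
action_space = BoxActionSpace(low=torch.tensor([-1.0]), high=torch.tensor([1.0]))

td3 = TD3(
    state_dim=4,
    action_space=action_space,
    actor_hidden_dims=[256, 256],
    critic_hidden_dims=[256, 256],
    exploration_module=None,  # NoExploration or a noise-based module in practice
    actor_update_freq=2,      # one actor update per 2 critic updates
    actor_update_noise=0.2,   # std of the target policy smoothing noise
    actor_update_noise_clip=0.5,
)

# Each learn_batch call performs one critic update and, every
# actor_update_freq calls, one actor update plus soft target updates:
# td3.learn_batch(batch)  # batch: TransitionBatch sampled from a replay buffer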
