Module pearl.policy_learners.sequential_decision_making.implicit_q_learning
Expand source code
from typing import Any, Dict, List, Optional, Type
import torch
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
)
from pearl.neural_networks.common.utils import update_target_networks
from pearl.neural_networks.common.value_networks import (
QValueNetwork,
ValueNetwork,
VanillaQValueNetwork,
VanillaValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.actor_networks import (
ActorNetwork,
GaussianActorNetwork,
VanillaActorNetwork,
VanillaContinuousActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.twin_critic import TwinCritic
from pearl.policy_learners.exploration_modules.common.no_exploration import (
NoExploration,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.sequential_decision_making.actor_critic_base import (
ActorCriticBase,
twin_critic_action_value_update,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import optim
class ImplicitQLearning(ActorCriticBase):
"""
Implementation of Implicit Q learning, an offline RL algorithm:
https://arxiv.org/pdf/2110.06169.pdf.
Author implementation in Jax: https://github.com/ikostrikov/implicit_q_learning
Algorithm implementation:
    - perform value, critic, and actor updates sequentially
    - soft update target networks of twin critics using the soft update rate (tau)
Notes:
1) Currently written for discrete action spaces. For continuous action spaces, we
need to implement the reparameterization trick.
2) This implementation uses twin critic (clipped double q learning) to reduce
overestimation bias. See TwinCritic class for implementation details.
Args:
expectile: a value between 0 and 1, for expectile regression
temperature_advantage_weighted_regression: temperature parameter for advantage
weighted regression; used to extract policy from trained value and critic networks.
"""
def __init__(
self,
state_dim: int,
action_space: ActionSpace,
actor_hidden_dims: List[int],
critic_hidden_dims: List[int],
value_critic_hidden_dims: List[int],
exploration_module: Optional[ExplorationModule] = None,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
value_network_type: Type[ValueNetwork] = VanillaValueNetwork,
value_critic_learning_rate: float = 1e-3,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
critic_soft_update_tau: float = 0.05,
discount_factor: float = 0.99,
training_rounds: int = 5,
batch_size: int = 128,
expectile: float = 0.5,
temperature_advantage_weighted_regression: float = 0.5,
advantage_clamp: float = 100.0,
action_representation_module: Optional[ActionRepresentationModule] = None,
) -> None:
super(ImplicitQLearning, self).__init__(
state_dim=state_dim,
action_space=action_space,
actor_hidden_dims=actor_hidden_dims,
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
use_critic_target=True,
critic_soft_update_tau=critic_soft_update_tau,
use_twin_critic=True,
exploration_module=exploration_module
if exploration_module is not None
else NoExploration(),
discount_factor=discount_factor,
training_rounds=training_rounds,
batch_size=batch_size,
is_action_continuous=action_space.is_continuous, # inferred from the action space
on_policy=False,
action_representation_module=action_representation_module,
)
self._expectile = expectile
self._is_action_continuous: bool = action_space.is_continuous
# TODO: base actor networks on a base class, and differentiate between
        # discrete and continuous actor networks, as well as stochastic and deterministic actors
if self._is_action_continuous:
torch._assert(
actor_network_type == GaussianActorNetwork
or actor_network_type == VanillaContinuousActorNetwork,
"continuous action space requires a deterministic or a stochastic actor which works"
"with continuous action spaces",
)
self._temperature_advantage_weighted_regression = (
temperature_advantage_weighted_regression
)
self._advantage_clamp = advantage_clamp
        # IQL uses both Q and V approximators
self._value_network: ValueNetwork = value_network_type(
input_dim=state_dim,
hidden_dims=value_critic_hidden_dims,
output_dim=1,
)
self._value_network_optimizer = optim.AdamW(
self._value_network.parameters(),
lr=value_critic_learning_rate,
amsgrad=True,
)
def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
self._critic_optimizer.add_param_group({"params": value.parameters()})
self._value_network_optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value
def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
value_loss = self._value_learn_batch(batch) # update value network
critic_loss = self._critic_learn_batch(batch) # update critic networks
        # soft update the target networks of the twin critics
update_target_networks(
self._critic_target._critic_networks_combined,
self._critic._critic_networks_combined,
self._critic_soft_update_tau,
)
actor_loss = self._actor_learn_batch(batch) # update actor network
return {
"value_loss": value_loss,
"actor_loss": actor_loss,
"critic_loss": critic_loss,
}
def _value_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
with torch.no_grad():
q1, q2 = self._critic_target.get_q_values(batch.state, batch.action)
# random ensemble distillation. TODO: clipped double q-learning
random_index = torch.randint(0, 2, (1,)).item()
target_q = q1 if random_index == 0 else q2 # shape: (batch_size)
value_batch = self._value_network(batch.state).view(-1) # shape: (batch_size)
# note the change in loss function from a mean square loss to an expectile loss
loss_value_network = self._expectile_loss(target_q - value_batch).mean()
self._value_network_optimizer.zero_grad()
loss_value_network.backward()
self._value_network_optimizer.step()
return {"value_loss": loss_value_network.mean().item()}
    def _actor_learn_batch(self, batch: TransitionBatch) -> float:
"""
Performs policy extraction using advantage weighted regression
"""
with torch.no_grad():
q1, q2 = self._critic_target.get_q_values(batch.state, batch.action)
# random ensemble distillation. TODO: clipped double q-learning
random_index = torch.randint(0, 2, (1,)).item()
target_q = q1 if random_index == 0 else q2 # shape: (batch_size)
value_batch = self._value_network(batch.state).view(
-1
) # shape: (batch_size)
advantage = torch.exp(
(target_q - value_batch)
* self._temperature_advantage_weighted_regression
) # shape: (batch_size)
advantage = torch.clamp(advantage, max=self._advantage_clamp)
# TODO: replace VanillaContinuousActorNetwork by a base class for
# deterministic actors
if isinstance(self._actor, VanillaContinuousActorNetwork):
# mean square error between the actor network output and action batch
loss = (
(self._actor.sample_action(batch.state) - batch.action)
.pow(2)
.mean(dim=1)
) # shape: (batch_size)
# advantage weighted regression loss for training deterministic actors
actor_loss = (advantage * loss).mean()
else:
if self.is_action_continuous:
log_action_probabilities = self._actor.get_log_probability(
batch.state, batch.action
).view(
-1
) # shape: (batch_size)
else:
action_probabilities = self._actor(
batch.state
) # shape: (batch_size, action_space_size)
# one_hot to action indices
action_idx = torch.argmax(batch.action, dim=1).unsqueeze(-1)
# gather log probabilities of actions in the dataset
log_action_probabilities = torch.log(
torch.gather(action_probabilities, 1, action_idx).view(-1)
)
# advantage weighted regression for stochastic actors
actor_loss = -(advantage * log_action_probabilities).mean()
self._actor_optimizer.zero_grad()
actor_loss.backward()
self._actor_optimizer.step()
        return actor_loss.item()
def _critic_learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
with torch.no_grad():
# sample values of next states
values_next_states = self._value_network(batch.next_state).view(-1)
# shape: (batch_size)
            # TODO: add a 'get'-style interface to vanilla value networks,
            # like the one vanilla Q-value networks expose
# compute targets for batch of (state, action, next_state): target y = r + gamma * V(s')
target = (
values_next_states * self._discount_factor * (1 - batch.done.float())
) + batch.reward # shape: (batch_size)
assert isinstance(
self._critic, TwinCritic
), "Critic in ImplicitQLearning should be TwinCritic"
# update twin critics towards target
loss_critic_update = twin_critic_action_value_update(
state_batch=batch.state,
action_batch=batch.action,
expected_target_batch=target,
optimizer=self._critic_optimizer,
critic=self._critic,
)
return loss_critic_update
    # This method is not expected to be reused by other algorithms, so it is defined here.
    # TODO: move it to a shared utils module if other algorithms need it in the future.
def _expectile_loss(self, input_loss: torch.Tensor) -> torch.Tensor:
"""
        Expectile loss applies an asymmetric quadratic penalty to the input residuals:
        positive residuals are weighted by self._expectile, negative ones
        by (1 - self._expectile).
"""
weight = torch.where(input_loss > 0, self._expectile, (1 - self._expectile))
return weight * (input_loss.pow(2))
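For intuition on the two key pieces above: the expectile loss weights positive residuals (target Q above the current value estimate) by tau and negative ones by (1 - tau), so tau > 0.5 biases V(s) toward an upper expectile of the target-Q distribution; the actor update then weights each action's regression term by a clamped exponential of the advantage. A minimal standalone sketch in plain PyTorch (the names expectile_loss and awr_weight are illustrative and not part of this module):

import torch

def expectile_loss(residual: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    # Asymmetric quadratic penalty, mirroring _expectile_loss above: weight
    # `expectile` on positive residuals, `1 - expectile` on negative ones.
    weight = torch.where(residual > 0, expectile, 1 - expectile)
    return weight * residual.pow(2)

def awr_weight(
    q: torch.Tensor, v: torch.Tensor, temperature: float = 3.0, clamp: float = 100.0
) -> torch.Tensor:
    # Advantage weighted regression weight, mirroring _actor_learn_batch above:
    # exp(temperature * (Q - V)), clamped to avoid exploding weights.
    return torch.clamp(torch.exp(temperature * (q - v)), max=clamp)

residuals = torch.tensor([-1.0, 1.0])
print(expectile_loss(residuals, expectile=0.7))  # tensor([0.3000, 0.7000])
print(awr_weight(torch.tensor([2.0]), torch.tensor([0.0])))  # tensor([100.])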
Classes
class ImplicitQLearning (state_dim: int, action_space: ActionSpace, actor_hidden_dims: List[int], critic_hidden_dims: List[int], value_critic_hidden_dims: List[int], exploration_module: Optional[ExplorationModule] = None, actor_network_type: Type[ActorNetwork] = pearl.neural_networks.sequential_decision_making.actor_networks.VanillaActorNetwork, critic_network_type: Type[QValueNetwork] = pearl.neural_networks.common.value_networks.VanillaQValueNetwork, value_network_type: Type[ValueNetwork] = pearl.neural_networks.common.value_networks.VanillaValueNetwork, value_critic_learning_rate: float = 0.001, actor_learning_rate: float = 0.001, critic_learning_rate: float = 0.001, critic_soft_update_tau: float = 0.05, discount_factor: float = 0.99, training_rounds: int = 5, batch_size: int = 128, expectile: float = 0.5, temperature_advantage_weighted_regression: float = 0.5, advantage_clamp: float = 100.0, action_representation_module: Optional[ActionRepresentationModule] = None)
-
Implementation of Implicit Q learning, an offline RL algorithm: https://arxiv.org/pdf/2110.06169.pdf. Author implementation in Jax: https://github.com/ikostrikov/implicit_q_learning
Algorithm implementation:
- perform value, critic, and actor updates sequentially
- soft update target networks of twin critics using the soft update rate (tau)
Notes:
1) Currently written for discrete action spaces. For continuous action spaces, we need to implement the reparameterization trick.
2) This implementation uses a twin critic (clipped double Q-learning) to reduce overestimation bias. See the TwinCritic class for implementation details.
Args
expectile
- a value between 0 and 1, for expectile regression
temperature_advantage_weighted_regression
- temperature parameter for advantage weighted regression; used to extract a policy from the trained value and critic networks.
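A hypothetical construction sketch for offline training. The DiscreteActionSpace import path, the action-space setup, and the offline batch source are assumptions about the surrounding Pearl API, not guarantees from this module; expectile=0.7 and temperature 3.0 follow the values the IQL paper reports for locomotion tasks:

import torch
from pearl.policy_learners.sequential_decision_making.implicit_q_learning import (
    ImplicitQLearning,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Assumed setup: a 4-action discrete space over an 8-dimensional state.
action_space = DiscreteActionSpace(actions=[torch.tensor([i]) for i in range(4)])
iql = ImplicitQLearning(
    state_dim=8,
    action_space=action_space,
    actor_hidden_dims=[64, 64],
    critic_hidden_dims=[64, 64],
    value_critic_hidden_dims=[64, 64],
    expectile=0.7,
    temperature_advantage_weighted_regression=3.0,
)

# IQL is offline: batches come from a fixed dataset, not environment interaction.
# for batch in offline_transition_batches:  # TransitionBatch instances
#     losses = iql.learn_batch(batch)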
Ancestors
- ActorCriticBase
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Methods
def set_history_summarization_module(self, value: HistorySummarizationModule) ‑> None
- Registers the parameters of the history summarization module with the actor, critic, and value-network optimizers, and stores the module on the learner.
Inherited members