Module pearl.policy_learners.sequential_decision_making.bootstrapped_dqn
Source code
from copy import deepcopy
from typing import Any, Dict, Optional

import torch
from pearl.action_representation_modules.action_representation_module import (
    ActionRepresentationModule,
)
from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.utils import update_target_network
from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork
from pearl.policy_learners.exploration_modules.sequential_decision_making.deep_exploration import (
    DeepExploration,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.replay_buffers.transition import (
    filter_batch_by_bootstrap_mask,
    TransitionBatch,
    TransitionWithBootstrapMaskBatch,
)
from torch import optim, Tensor


class BootstrappedDQN(DeepQLearning):
    r"""Bootstrapped DQN, proposed by [1], is an extension of DQN that uses
    the so-called "deep exploration" mechanism. The main idea is to keep
    an ensemble of `K` Q-value networks and on each episode, one of them is
    sampled and the greedy policy associated with that network is used for
    exploration.

    [1] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin
    Van Roy, Deep exploration via bootstrapped DQN. Advances in Neural
    Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.
    """

    def __init__(
        self,
        action_space: ActionSpace,
        q_ensemble_network: EnsembleQValueNetwork,
        discount_factor: float = 0.99,
        learning_rate: float = 0.001,
        training_rounds: int = 100,
        batch_size: int = 128,
        target_update_freq: int = 10,
        soft_update_tau: float = 1.0,
        action_representation_module: Optional[ActionRepresentationModule] = None,
    ) -> None:
        PolicyLearner.__init__(
            self=self,
            training_rounds=training_rounds,
            batch_size=batch_size,
            exploration_module=DeepExploration(q_ensemble_network),
            on_policy=False,
            is_action_continuous=False,
            action_representation_module=action_representation_module,
        )
        self._action_space = action_space
        self._learning_rate = learning_rate
        self._discount_factor = discount_factor
        self._target_update_freq = target_update_freq
        self._soft_update_tau = soft_update_tau
        self._Q = q_ensemble_network
        self._Q_target: EnsembleQValueNetwork = deepcopy(self._Q)
        self._optimizer = optim.AdamW(
            self._Q.parameters(), lr=self._learning_rate, amsgrad=True
        )

    @property
    def ensemble_size(self) -> int:
        return self._Q.ensemble_size

    def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
        if not isinstance(batch, TransitionWithBootstrapMaskBatch):
            raise TypeError(
                f"{type(self).__name__} requires a batch of type "
                f"`TransitionWithBootstrapMaskBatch`, but got {type(batch)}."
            )
        loss_ensemble = torch.tensor(0.0).to(batch.device)
        mask = batch.bootstrap_mask
        for z in range(self.ensemble_size):
            z = torch.tensor(z).to(batch.device)
            # if this batch doesn't have any items for the z-th ensemble, move on
            if mask is None or mask[:, z].sum() == 0:
                continue
            # filter the batch to only the transitions belonging to ensemble `z`
            batch_filtered = filter_batch_by_bootstrap_mask(batch=batch, z=z)
            state_action_values = self._Q.get_q_values(
                state_batch=batch_filtered.state,
                action_batch=batch_filtered.action,
                curr_available_actions_batch=batch_filtered.curr_available_actions,
                z=z,
            )
            # compute the Bellman target
            expected_state_action_values = (
                self._get_next_state_values(
                    batch=batch_filtered, batch_size=batch_filtered.state.shape[0], z=z
                )
                * self._discount_factor
                * (1 - batch_filtered.done.float())
            ) + batch_filtered.reward  # (batch_size), r + gamma * V(s)
            criterion = torch.nn.MSELoss()
            loss = criterion(state_action_values, expected_state_action_values)
            loss_ensemble += loss

        # Optimize the model
        self._optimizer.zero_grad()
        loss_ensemble.backward()
        self._optimizer.step()

        # Target Network Update
        if (self._training_steps + 1) % self._target_update_freq == 0:
            update_target_network(self._Q_target, self._Q, self._soft_update_tau)
        return {"loss": loss_ensemble.mean().item()}

    def reset(self, action_space: ActionSpace) -> None:
        # Reset the `DeepExploration` module, which will resample the epistemic index.
        self._exploration_module.reset()

    @torch.no_grad()
    def _get_next_state_values(
        self, batch: TransitionBatch, batch_size: int, z: Optional[Tensor] = None
    ) -> torch.Tensor:
        (
            next_state,
            next_available_actions,
            next_available_actions_mask,
        ) = self._prepare_next_state_action_batch(batch)
        assert next_available_actions is not None
        # For dueling architectures this performs a forward pass, since the batch
        # of next available actions is already provided as input.
        # (batch_size x action_space_size)
        next_state_action_values = self._Q.get_q_values(
            state_batch=next_state, action_batch=next_available_actions, z=z
        ).view(batch_size, -1)
        target_next_state_action_values = self._Q_target.get_q_values(
            state_batch=next_state, action_batch=next_available_actions, z=z
        ).view(batch_size, -1)
        # Make sure that unavailable actions' Q values are assigned to -inf
        next_state_action_values[next_available_actions_mask] = -float("inf")
        # Get argmax actions indices
        argmax_actions = next_state_action_values.max(1)[1]  # (batch_size)
        return target_next_state_action_values[
            torch.arange(batch_size), argmax_actions
        ]  # (batch_size)
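
The target computed by `_get_next_state_values` is a double-DQN style target: the online network `self._Q` of the sampled ensemble member selects the greedy next action (after masking unavailable actions to -inf), and the target network `self._Q_target` evaluates it; `learn_batch` then forms r + gamma * (1 - done) * Q_target(s', a*). The sketch below illustrates that computation with plain tensors; `q_online`, `q_target`, and `unavailable_mask` are illustrative stand-ins, not part of Pearl's API.

import torch

batch_size, num_actions = 4, 3
discount_factor = 0.99

# Stand-ins for next-state Q-values from the online (self._Q) and target
# (self._Q_target) networks of the sampled ensemble member z.
q_online = torch.randn(batch_size, num_actions)
q_target = torch.randn(batch_size, num_actions)

# Mask unavailable actions before taking the argmax, as in _get_next_state_values.
unavailable_mask = torch.zeros(batch_size, num_actions, dtype=torch.bool)
unavailable_mask[0, 2] = True  # e.g. action 2 is unavailable in the first state
q_online = q_online.masked_fill(unavailable_mask, -float("inf"))

# Double-DQN: argmax under the online network, value under the target network.
argmax_actions = q_online.max(dim=1).indices  # (batch_size,)
next_state_values = q_target[torch.arange(batch_size), argmax_actions]

# Bellman target r + gamma * (1 - done) * Q_target(s', a*), as in learn_batch.
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)
target = reward + discount_factor * (1.0 - done) * next_state_values  # (batch_size,)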
Classes
class BootstrappedDQN (action_space: ActionSpace, q_ensemble_network: EnsembleQValueNetwork, discount_factor: float = 0.99, learning_rate: float = 0.001, training_rounds: int = 100, batch_size: int = 128, target_update_freq: int = 10, soft_update_tau: float = 1.0, action_representation_module: Optional[ActionRepresentationModule] = None)
Bootstrapped DQN, proposed by [1], is an extension of DQN that uses the so-called "deep exploration" mechanism. The main idea is to keep an ensemble of `K` Q-value networks and on each episode, one of them is sampled and the greedy policy associated with that network is used for exploration.

[1] Ian Osband, Charles Blundell, Alexander Pritzel, and Benjamin Van Roy. Deep exploration via bootstrapped DQN. Advances in Neural Information Processing Systems, 2016. https://arxiv.org/abs/1602.04621.
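A typical construction, sketched below, builds an EnsembleQValueNetwork and passes it to the learner together with the environment's action space. This is a rough outline only: the keyword arguments used for EnsembleQValueNetwork, the DiscreteActionSpace construction, and its import path are assumptions and should be checked against the corresponding class definitions. Note also that learn_batch only accepts TransitionWithBootstrapMaskBatch, so the learner must be paired with a replay buffer that attaches bootstrap masks to its transitions.

import torch

from pearl.neural_networks.common.value_networks import EnsembleQValueNetwork
from pearl.policy_learners.sequential_decision_making.bootstrapped_dqn import (
    BootstrappedDQN,
)
# Import path and constructor below are assumptions; adjust to your Pearl version.
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

state_dim, num_actions = 8, 4

# Assumed EnsembleQValueNetwork keyword arguments; verify against value_networks.py.
q_ensemble_network = EnsembleQValueNetwork(
    state_dim=state_dim,
    action_dim=num_actions,
    hidden_dims=[64, 64],
    output_dim=1,
    ensemble_size=10,
)

# Assumed to accept a list of single-element action tensors.
action_space = DiscreteActionSpace(
    actions=[torch.tensor([i]) for i in range(num_actions)]
)

policy_learner = BootstrappedDQN(
    action_space=action_space,
    q_ensemble_network=q_ensemble_network,
    training_rounds=20,
    batch_size=128,
    target_update_freq=10,
)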
Ancestors
- DeepQLearning
- DeepTDLearning
- PolicyLearner
- torch.nn.modules.module.Module
- abc.ABC
Instance variables
var ensemble_size : int
The number of Q-value networks in the ensemble.
Inherited members