Module pearl.safety_modules.risk_sensitive_safety_modules
from abc import abstractmethod
from typing import Optional

import torch
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.neural_networks.sequential_decision_making.q_value_network import (
    DistributionalQValueNetwork,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.safety_module import SafetyModule
from torch import Tensor


class RiskSensitiveSafetyModule(SafetyModule):
    """
    A safety module that computes Q-values from a Q-value distribution given a risk measure.
    Base class for different risk metrics, e.g. mean-variance, value-at-risk (VaR), etc.
    """

    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        return action_space

    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        pass

    def learn_batch(self, batch: TransitionBatch) -> None:
        pass

    # Risk-sensitive safe RL methods use this to compute Q-values from a Q-value distribution.
    @abstractmethod
    def get_q_values_under_risk_metric(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        q_value_distribution_network: DistributionalQValueNetwork,
    ) -> torch.Tensor:
        pass


class RiskNeutralSafetyModule(RiskSensitiveSafetyModule):
    """
    A safety module that computes Q-values as the expectation of a Q-value distribution.
    """

    def __init__(self) -> None:
        super(RiskNeutralSafetyModule, self).__init__()

    def __str__(self) -> str:
        return f"Safety module type {self.__class__.__name__}"

    def get_q_values_under_risk_metric(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        q_value_distribution_network: DistributionalQValueNetwork,
    ) -> torch.Tensor:
        """Returns Q(s, a), given s and a.

        Args:
            state_batch: a batch of state tensors (batch_size, state_dim)
            action_batch: a batch of action tensors (batch_size, action_dim)
            q_value_distribution_network: a distributional Q-value network that
                approximates the return distribution
        Returns:
            Q-values of (state, action) pairs, of shape (batch_size,), under a
            risk-neutral measure, that is, Q(s, a) = E[Z(s, a)]
        """
        q_value_distribution = q_value_distribution_network.get_q_value_distribution(
            state_batch, action_batch
        )
        return q_value_distribution.mean(dim=-1)


class QuantileNetworkMeanVarianceSafetyModule(RiskSensitiveSafetyModule):
    """
    A safety module that computes Q-values as a weighted linear combination
    of the mean and variance of the Q-value distribution:
    Q(s, a) = E[Z(s, a)] - beta * Var[Z(s, a)]
    """

    def __init__(
        self,
        variance_weighting_coefficient: float,
    ) -> None:
        super(QuantileNetworkMeanVarianceSafetyModule, self).__init__()
        self._beta = variance_weighting_coefficient

    def __str__(self) -> str:
        return f"Safety module type {self.__class__.__name__}"

    def get_q_values_under_risk_metric(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        q_value_distribution_network: DistributionalQValueNetwork,
    ) -> torch.Tensor:
        q_value_distribution = q_value_distribution_network.get_q_value_distribution(
            state_batch,
            action_batch,
        )
        # Variance computation:
        #   Var[Z(s, a)] ~= sum_{i=0}^{N-1} (tau_{i+1} - tau_i) * (Z_{tau_i}(s, a) - E[Z(s, a)])^2
        mean_value = q_value_distribution.mean(dim=-1, keepdim=True)
        quantiles = q_value_distribution_network.quantiles
        quantile_differences = quantiles[1:] - quantiles[:-1]
        variance = (
            quantile_differences * torch.square(q_value_distribution - mean_value)
        ).sum(dim=-1, keepdim=True)
        variance_adjusted_mean = (mean_value - (self._beta * variance)).view(-1)
        return variance_adjusted_mean
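The mean-variance adjustment implemented by QuantileNetworkMeanVarianceSafetyModule above can be checked with plain torch operations. The following is a minimal sketch with made-up quantile values and a made-up beta; it mirrors the computation in get_q_values_under_risk_metric and is not part of the Pearl API.

import torch

# One (state, action) pair with 4 quantile values of Z(s, a); numbers are made up.
z = torch.tensor([[0.0, 2.0, 4.0, 6.0]])
quantiles = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])  # one more boundary than values

mean = z.mean(dim=-1, keepdim=True)                     # E[Z(s, a)] = 3.0
diffs = quantiles[1:] - quantiles[:-1]                  # tau_{i+1} - tau_i = 0.25 each
var = (diffs * torch.square(z - mean)).sum(dim=-1, keepdim=True)  # 0.25 * (9+1+1+9) = 5.0
beta = 0.1
print((mean - beta * var).view(-1))                     # tensor([2.5000])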
Classes
class QuantileNetworkMeanVarianceSafetyModule (variance_weighting_coefficient: float)
-
A safety module that computes Q-values as a weighted linear combination of the mean and variance of the Q-value distribution: Q(s, a) = E[Z(s, a)] - beta * Var[Z(s, a)]
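A minimal usage sketch, assuming the module is constructed directly with its variance_weighting_coefficient as in the signature above. The stub network is hypothetical and only stands in for a trained DistributionalQValueNetwork: it exposes the two members the method actually reads (get_q_value_distribution and quantiles). It shows how a larger beta penalizes a high-variance action relative to a low-variance one with the same mean.

import torch
from pearl.safety_modules.risk_sensitive_safety_modules import (
    QuantileNetworkMeanVarianceSafetyModule,
)

class _StubDistributionalNetwork:  # hypothetical stand-in for a trained network
    quantiles = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

    def get_q_value_distribution(self, state_batch, action_batch):
        # Two candidate actions with the same mean return but different spread.
        return torch.tensor([[0.0, 1.0, 2.0, 3.0],
                             [1.4, 1.5, 1.5, 1.6]])

states = torch.zeros(2, 3)   # placeholder state batch (batch_size=2, state_dim=3)
actions = torch.zeros(2, 1)  # placeholder action batch

risk_neutral_like = QuantileNetworkMeanVarianceSafetyModule(variance_weighting_coefficient=0.0)
risk_averse = QuantileNetworkMeanVarianceSafetyModule(variance_weighting_coefficient=0.5)

print(risk_neutral_like.get_q_values_under_risk_metric(states, actions, _StubDistributionalNetwork()))
# tensor([1.5000, 1.5000]) -- beta = 0 reduces to the plain mean
print(risk_averse.get_q_values_under_risk_metric(states, actions, _StubDistributionalNetwork()))
# tensor([0.8750, 1.4975]) -- the high-variance action is penalized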
Ancestors
- RiskSensitiveSafetyModule
- SafetyModule
- abc.ABC
Methods
def get_q_values_under_risk_metric(self, state_batch: torch.Tensor, action_batch: torch.Tensor, q_value_distribution_network: DistributionalQValueNetwork) ‑> torch.Tensor
-
Computes the variance-adjusted Q-values Q(s, a) = E[Z(s, a)] - beta * Var[Z(s, a)] for a batch of (state, action) pairs, using the quantiles of the given distributional Q-value network.
class RiskNeutralSafetyModule
-
A safety module that computes Q-values as the expectation of a Q-value distribution.
Ancestors
- RiskSensitiveSafetyModule
- SafetyModule
- abc.ABC
Methods
def get_q_values_under_risk_metric(self, state_batch: torch.Tensor, action_batch: torch.Tensor, q_value_distribution_network: DistributionalQValueNetwork) ‑> torch.Tensor
-
Returns Q(s, a), given s and a
Args
state_batch
- a batch of state tensors (batch_size, state_dim)
action_batch
- a batch of action tensors (batch_size, action_dim)
q_value_distribution_network
- a distributional q value network that approximates the return distribution
Returns
Q-values of (state, action) pairs, of shape (batch_size,), under a risk-neutral measure, that is, Q(s, a) = E[Z(s, a)]
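A minimal usage sketch matching the Args and Returns above. The stub network and all tensor values are hypothetical placeholders for a trained DistributionalQValueNetwork.

import torch
from pearl.safety_modules.risk_sensitive_safety_modules import RiskNeutralSafetyModule

class _StubDistributionalNetwork:  # hypothetical stand-in for a trained network
    def get_q_value_distribution(self, state_batch, action_batch):
        # (batch_size, num_quantiles) quantile values of Z(s, a)
        return torch.tensor([[0.0, 1.0, 2.0, 3.0],
                             [2.0, 2.0, 2.0, 2.0]])

module = RiskNeutralSafetyModule()
q = module.get_q_values_under_risk_metric(
    state_batch=torch.zeros(2, 3),   # (batch_size, state_dim), placeholder values
    action_batch=torch.zeros(2, 1),  # (batch_size, action_dim), placeholder values
    q_value_distribution_network=_StubDistributionalNetwork(),
)
print(q)  # tensor([1.5000, 2.0000]) -- Q(s, a) = E[Z(s, a)], shape (batch_size,)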
class RiskSensitiveSafetyModule
-
A safety module that computes Q-values from a Q-value distribution given a risk measure. Base class for different risk metrics, e.g. mean-variance, value-at-risk (VaR), etc.
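As an illustration of how a new risk metric could be built on this base class, the sketch below defines a hypothetical value-at-risk (VaR) module. It is not part of Pearl; the class name, the alpha parameter, and the assumption that the network's quantile values are approximately sorted in ascending order are all illustrative choices.

import torch
from torch import Tensor
from pearl.neural_networks.sequential_decision_making.q_value_network import (
    DistributionalQValueNetwork,
)
from pearl.safety_modules.risk_sensitive_safety_modules import RiskSensitiveSafetyModule

class ValueAtRiskSafetyModule(RiskSensitiveSafetyModule):  # hypothetical example
    def __init__(self, alpha: float = 0.1) -> None:
        super().__init__()
        self._alpha = alpha  # tail level, e.g. 0.1 for a 10% VaR

    def get_q_values_under_risk_metric(
        self,
        state_batch: Tensor,
        action_batch: Tensor,
        q_value_distribution_network: DistributionalQValueNetwork,
    ) -> torch.Tensor:
        z = q_value_distribution_network.get_q_value_distribution(state_batch, action_batch)
        num_values = z.shape[-1]
        # Index of the alpha-quantile among the supported quantile values
        # (assumes the values along the last dimension are roughly sorted).
        index = min(int(self._alpha * num_values), num_values - 1)
        # Score each (s, a) by a pessimistic, lower-tail return estimate.
        return z[..., index]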
Ancestors
- SafetyModule
- abc.ABC
Subclasses
- QuantileNetworkMeanVarianceSafetyModule
- RiskNeutralSafetyModule
Methods
def filter_action(self, subjective_state: torch.Tensor, action_space: ActionSpace) ‑> ActionSpace
-
Returns the action space unchanged; risk-sensitive safety modules do not filter actions.
def get_q_values_under_risk_metric(self, state_batch: torch.Tensor, action_batch: torch.Tensor, q_value_distribution_network: DistributionalQValueNetwork) ‑> torch.Tensor
-
Abstract method. Risk-sensitive safe RL methods use this to compute Q-values from a Q-value distribution under the module's risk metric.
def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) ‑> None
-
No-op: this safety module does not learn from the replay buffer.
def learn_batch(self, batch: TransitionBatch) ‑> None
-
No-op: this safety module does not learn from transition batches.