Module pearl.safety_modules.reward_constrained_safety_module
Source code
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from typing import Optional

import torch

from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.policy_learner import PolicyLearner
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.safety_module import SafetyModule


class RewardConstrainedSafetyModule(SafetyModule):
    """
    Placeholder: to be implemented
    """

    def __init__(
        self,
        constraint_value: float,
        lambda_constraint_ub_value: float,
        lambda_constraint_init_value: float = 0.0,
        lr_lambda: float = 0.001,
        batch_size: int = 256,
    ) -> None:
        super(RewardConstrainedSafetyModule, self).__init__()
        self.constraint_value = constraint_value
        self.lr_lambda = lr_lambda
        self.lambda_constraint_ub_value = lambda_constraint_ub_value
        self.lambda_constraint = lambda_constraint_init_value
        self.batch_size = batch_size
        self._action_space: Optional[ActionSpace] = None

    def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None:
        # Wait until the buffer can serve a full batch, then update the multiplier.
        if len(replay_buffer) < self.batch_size or len(replay_buffer) == 0:
            return
        batch = replay_buffer.sample(self.batch_size)
        assert isinstance(batch, TransitionBatch)
        self.constraint_lambda_update(batch, policy_learner)

    def constraint_lambda_update(
        self, batch: TransitionBatch, policy_learner: PolicyLearner
    ) -> None:
        # Estimate the expected cost of the current policy: evaluate the twin
        # cost critics at actions sampled from the actor and keep the larger
        # (more pessimistic) estimate, averaged over the batch.
        with torch.no_grad():
            cost_q1, cost_q2 = policy_learner.cost_critic.get_q_values(
                state_batch=batch.state,
                action_batch=policy_learner._actor.sample_action(batch.state),
            )
            cost_q = torch.maximum(cost_q1, cost_q2)
            cost_q = cost_q.mean().item()

        # Dual-ascent step: raise lambda when the scaled expected cost exceeds
        # the constraint value, then clip to [0, lambda_constraint_ub_value].
        lambda_update = self.lambda_constraint + self.lr_lambda * (
            cost_q * (1 - policy_learner.cost_discount_factor) - self.constraint_value
        )
        lambda_update = max(lambda_update, 0.0)
        lambda_update = min(lambda_update, self.lambda_constraint_ub_value)
        self.lambda_constraint = lambda_update

    def filter_action(
        self, subjective_state: SubjectiveState, action_space: ActionSpace
    ) -> ActionSpace:
        # No filtering: the full action space is returned unchanged.
        return action_space

    def learn_batch(self, batch: TransitionBatch) -> None:
        pass

    def reset(self, action_space: ActionSpace) -> None:
        self._action_space = action_space
Classes
class RewardConstrainedSafetyModule (constraint_value: float, lambda_constraint_ub_value: float, lambda_constraint_init_value: float = 0.0, lr_lambda: float = 0.001, batch_size: int = 256)
Placeholder: to be implemented
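The docstring is still a placeholder, but the constructor signature and the source above show how the module is meant to be driven. Below is a minimal, hypothetical sketch (not taken from the Pearl docs): it constructs the module and triggers one multiplier update through constraint_lambda_update(), with stub classes (StubCostCritic, StubActor, StubPolicyLearner, StubBatch) and toy values standing in for the cost critic, actor, and batch that a real policy learner and replay buffer would provide.

import torch

from pearl.safety_modules.reward_constrained_safety_module import (
    RewardConstrainedSafetyModule,
)


# Hypothetical stand-ins for the attributes constraint_lambda_update reads
# from the policy learner: a twin cost critic, an actor, and a cost discount.
class StubCostCritic:
    def get_q_values(self, state_batch, action_batch):
        q = torch.full((state_batch.shape[0],), 0.5)  # constant expected cost
        return q, q


class StubActor:
    def sample_action(self, state_batch):
        return torch.zeros(state_batch.shape[0], 1)


class StubPolicyLearner:
    cost_critic = StubCostCritic()
    _actor = StubActor()
    cost_discount_factor = 0.99


class StubBatch:
    state = torch.zeros(4, 3)  # four toy states


module = RewardConstrainedSafetyModule(
    constraint_value=0.1,             # per-step cost budget
    lambda_constraint_ub_value=20.0,  # cap on the multiplier
    lr_lambda=1e-3,
)
module.constraint_lambda_update(StubBatch(), StubPolicyLearner())
# 0.5 * (1 - 0.99) = 0.005 < 0.1, so the update is negative and the
# multiplier stays clipped at 0.0.
print(module.lambda_constraint)

In a real agent the batch comes from the replay buffer via learn(), and the policy learner is expected to expose cost_critic, _actor, and cost_discount_factor exactly as the stubs do here.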
Ancestors
- SafetyModule
- abc.ABC
Methods
def constraint_lambda_update(self, batch: TransitionBatch, policy_learner: PolicyLearner) -> None
Performs one projected gradient (dual ascent) step on the Lagrange multiplier: the twin cost critics are evaluated at actions sampled from the current actor, their elementwise maximum (the more pessimistic cost estimate) is averaged into a scalar, and lambda_constraint is moved by lr_lambda * (cost * (1 - cost_discount_factor) - constraint_value), then clipped to [0, lambda_constraint_ub_value].
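Read as arithmetic, this is a clipped gradient step on the dual variable of a cost constraint. A sketch of the same computation with hypothetical numbers (none of these values come from Pearl):

# Hypothetical values for one update step.
lambda_constraint = 0.5   # current multiplier
lr_lambda = 0.001
constraint_value = 0.1    # allowed (scaled) cost budget
cost_q = 15.0             # mean pessimistic cost Q-value under the actor
cost_discount_factor = 0.99

# lambda <- clip(lambda + lr * (Q_cost * (1 - gamma_cost) - d), 0, ub)
lambda_update = lambda_constraint + lr_lambda * (
    cost_q * (1 - cost_discount_factor) - constraint_value
)
lambda_update = max(lambda_update, 0.0)
lambda_update = min(lambda_update, 20.0)  # lambda_constraint_ub_value
print(lambda_update)  # 0.50005: 15.0 * 0.01 = 0.15 > 0.1, so lambda rises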
def filter_action(self, subjective_state: torch.Tensor, action_space: ActionSpace) -> ActionSpace
Returns action_space unchanged; this module does not restrict the available actions.
def learn(self, replay_buffer: ReplayBuffer, policy_learner: PolicyLearner) -> None
Returns immediately if the replay buffer holds fewer than batch_size transitions; otherwise samples a TransitionBatch of that size and updates the Lagrange multiplier via constraint_lambda_update().
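Because of the size guard, learn() is a no-op until the buffer can serve a full batch. A small hypothetical check (TinyBuffer is not a Pearl class, and passing None as the policy learner is only safe because the early return prevents it from being used):

from pearl.safety_modules.reward_constrained_safety_module import (
    RewardConstrainedSafetyModule,
)


class TinyBuffer:
    # Only __len__ is reached: learn() returns before sampling when the
    # buffer is smaller than batch_size.
    def __len__(self):
        return 10

    def sample(self, batch_size):
        raise AssertionError("not reached")


module = RewardConstrainedSafetyModule(
    constraint_value=0.1, lambda_constraint_ub_value=20.0, batch_size=256
)
module.learn(TinyBuffer(), policy_learner=None)  # returns early; nothing sampled
assert module.lambda_constraint == 0.0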
def learn_batch(self, batch: TransitionBatch) -> None
No-op; this module performs all of its learning through learn().
def reset(self, action_space: ActionSpace) -> None
Stores the given action_space on the module.