Module pearl.utils.compatibility_checks
Expand source code
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from pearl.policy_learners.policy_learner import (
DistributionalPolicyLearner,
PolicyLearner,
)
from pearl.policy_learners.sequential_decision_making.td3 import RCTD3
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.safety_modules.reward_constrained_safety_module import (
RewardConstrainedSafetyModule,
)
from pearl.safety_modules.risk_sensitive_safety_modules import RiskSensitiveSafetyModule
from pearl.safety_modules.safety_module import SafetyModule
def pearl_agent_compatibility_check(
    policy_learner: PolicyLearner,
    safety_module: SafetyModule,
    replay_buffer: ReplayBuffer,
) -> None:
    """
    Validate that the components selected for a Pearl agent can work together.

    Two pairing rules are enforced, in this order:

    1. A ``DistributionalPolicyLearner`` must be paired with a
       ``RiskSensitiveSafetyModule``.
    2. A ``RewardConstrainedSafetyModule`` must be paired with an ``RCTD3``
       policy learner.

    Args:
        policy_learner: The policy learner of the agent.
        safety_module: The safety module of the agent.
        replay_buffer: The replay buffer of the agent. It is accepted as part
            of the signature but is not inspected by the current checks.

    Raises:
        TypeError: If either pairing rule above is violated.
    """
    # Rule 1: distributional learners need a risk-sensitive safety module.
    if isinstance(policy_learner, DistributionalPolicyLearner) and not isinstance(
        safety_module, RiskSensitiveSafetyModule
    ):
        message = (
            "A distributional policy learner requires a risk-sensitive safety module."
        )
        raise TypeError(message)

    # Rule 2: the reward-constrained safety module only works with RCTD3.
    if isinstance(safety_module, RewardConstrainedSafetyModule) and not isinstance(
        policy_learner, RCTD3
    ):
        message = "An Reward Constrained Policy Optimization safety module requires RCTD3 policy learner."
        raise TypeError(message)
Functions
def pearl_agent_compatibility_check(policy_learner: PolicyLearner, safety_module: SafetyModule, replay_buffer: ReplayBuffer) ‑> None
-
Check if different modules of the Pearl agent are compatible with each other.
Expand source code
def pearl_agent_compatibility_check(
    policy_learner: PolicyLearner,
    safety_module: SafetyModule,
    replay_buffer: ReplayBuffer,
) -> None:
    """
    Validate that the components selected for a Pearl agent can work together.

    Two pairing rules are enforced, in this order:

    1. A ``DistributionalPolicyLearner`` must be paired with a
       ``RiskSensitiveSafetyModule``.
    2. A ``RewardConstrainedSafetyModule`` must be paired with an ``RCTD3``
       policy learner.

    Args:
        policy_learner: The policy learner of the agent.
        safety_module: The safety module of the agent.
        replay_buffer: The replay buffer of the agent. It is accepted as part
            of the signature but is not inspected by the current checks.

    Raises:
        TypeError: If either pairing rule above is violated.
    """
    # Rule 1: distributional learners need a risk-sensitive safety module.
    if isinstance(policy_learner, DistributionalPolicyLearner) and not isinstance(
        safety_module, RiskSensitiveSafetyModule
    ):
        message = (
            "A distributional policy learner requires a risk-sensitive safety module."
        )
        raise TypeError(message)

    # Rule 2: the reward-constrained safety module only works with RCTD3.
    if isinstance(safety_module, RewardConstrainedSafetyModule) and not isinstance(
        policy_learner, RCTD3
    ):
        message = "An Reward Constrained Policy Optimization safety module requires RCTD3 policy learner."
        raise TypeError(message)