Module pearl.utils.compatibility_checks

Expand source code
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from pearl.policy_learners.policy_learner import (
    DistributionalPolicyLearner,
    PolicyLearner,
)
from pearl.policy_learners.sequential_decision_making.td3 import RCTD3
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.safety_modules.reward_constrained_safety_module import (
    RewardConstrainedSafetyModule,
)
from pearl.safety_modules.risk_sensitive_safety_modules import RiskSensitiveSafetyModule
from pearl.safety_modules.safety_module import SafetyModule


def pearl_agent_compatibility_check(
    policy_learner: PolicyLearner,
    safety_module: SafetyModule,
    replay_buffer: ReplayBuffer,
) -> None:
    """
    Check if different modules of the Pearl agent are compatible with each other.

    Two pairings are currently enforced:

    * A ``DistributionalPolicyLearner`` must be paired with a
      ``RiskSensitiveSafetyModule``.
    * A ``RewardConstrainedSafetyModule`` must be paired with an ``RCTD3``
      policy learner.

    Args:
        policy_learner: The agent's policy learner.
        safety_module: The agent's safety module.
        replay_buffer: The agent's replay buffer. Accepted so the signature
            stays stable; no check in this function currently inspects it.

    Raises:
        TypeError: If either of the pairings above is violated.
    """
    if isinstance(policy_learner, DistributionalPolicyLearner) and not isinstance(
        safety_module, RiskSensitiveSafetyModule
    ):
        raise TypeError(
            "A distributional policy learner requires a risk-sensitive safety module."
        )

    if isinstance(safety_module, RewardConstrainedSafetyModule) and not isinstance(
        policy_learner, RCTD3
    ):
        raise TypeError(
            "A Reward Constrained Policy Optimization safety module requires "
            "an RCTD3 policy learner."
        )

Functions

def pearl_agent_compatibility_check(policy_learner: PolicyLearner, safety_module: SafetyModule, replay_buffer: ReplayBuffer) -> None

Check if different modules of the Pearl agent are compatible with each other.

Expand source code
def pearl_agent_compatibility_check(
    policy_learner: PolicyLearner,
    safety_module: SafetyModule,
    replay_buffer: ReplayBuffer,
) -> None:
    """
    Check if different modules of the Pearl agent are compatible with each other.

    Two pairings are currently enforced:

    * A ``DistributionalPolicyLearner`` must be paired with a
      ``RiskSensitiveSafetyModule``.
    * A ``RewardConstrainedSafetyModule`` must be paired with an ``RCTD3``
      policy learner.

    Args:
        policy_learner: The agent's policy learner.
        safety_module: The agent's safety module.
        replay_buffer: The agent's replay buffer. Accepted so the signature
            stays stable; no check in this function currently inspects it.

    Raises:
        TypeError: If either of the pairings above is violated.
    """
    if isinstance(policy_learner, DistributionalPolicyLearner) and not isinstance(
        safety_module, RiskSensitiveSafetyModule
    ):
        raise TypeError(
            "A distributional policy learner requires a risk-sensitive safety module."
        )

    if isinstance(safety_module, RewardConstrainedSafetyModule) and not isinstance(
        policy_learner, RCTD3
    ):
        raise TypeError(
            "A Reward Constrained Policy Optimization safety module requires "
            "an RCTD3 policy learner."
        )