Module pearl.pearl_agent
Expand source code
import typing
from typing import Any, Dict, Optional
import torch
from pearl.api.action import Action
from pearl.api.action_result import ActionResult
from pearl.api.action_space import ActionSpace
from pearl.api.agent import Agent
from pearl.api.observation import Observation
from pearl.api.state import SubjectiveState
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
)
from pearl.history_summarization_modules.identity_history_summarization_module import (
IdentityHistorySummarizationModule,
)
from pearl.policy_learners.policy_learner import (
DistributionalPolicyLearner,
PolicyLearner,
)
from pearl.replay_buffers.examples.single_transition_replay_buffer import (
SingleTransitionReplayBuffer,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.transition import TransitionBatch
from pearl.safety_modules.identity_safety_module import IdentitySafetyModule
from pearl.safety_modules.risk_sensitive_safety_modules import RiskNeutralSafetyModule
from pearl.safety_modules.safety_module import SafetyModule
from pearl.utils.compatibility_checks import pearl_agent_compatibility_check
from pearl.utils.device import get_pearl_device
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
class PearlAgent(Agent):
"""
    An Agent gathering the most common aspects of production-ready agents.
    It is meant as a catch-all agent whose functionality is defined by flags
    (and possibly factories down the line).
"""
default_safety_module_type = IdentitySafetyModule
default_risk_sensitive_safety_module_type = RiskNeutralSafetyModule
default_history_summarization_module_type = IdentityHistorySummarizationModule
default_replay_buffer_type = SingleTransitionReplayBuffer
# TODO: define a data structure that hosts the configs for a Pearl Agent
def __init__(
self,
policy_learner: PolicyLearner,
safety_module: Optional[SafetyModule] = None,
replay_buffer: Optional[ReplayBuffer] = None,
history_summarization_module: Optional[HistorySummarizationModule] = None,
device_id: int = -1,
) -> None:
"""
Initializes the PearlAgent.
Args:
policy_learner (PolicyLearner): An instance of PolicyLearner.
safety_module (SafetyModule, optional): An instance of SafetyModule. Defaults to
RiskNeutralSafetyModule for distributional policy learner types, and
IdentitySafetyModule for other types.
            history_summarization_module (HistorySummarizationModule, optional): An instance of
                HistorySummarizationModule. Defaults to IdentityHistorySummarizationModule.
            replay_buffer (ReplayBuffer, optional): A replay buffer. Defaults to a single-transition
                replay buffer (note: this default is likely to change).
            device_id (int, optional): The device id on which the agent's modules are placed.
                Defaults to -1.
"""
self.policy_learner: PolicyLearner = policy_learner
self._device_id: int = device_id
self.device: torch.device = get_pearl_device(device_id)
self.safety_module: SafetyModule = (
safety_module
if safety_module is not None
else (
PearlAgent.default_risk_sensitive_safety_module_type()
if isinstance(self.policy_learner, DistributionalPolicyLearner)
else PearlAgent.default_safety_module_type()
)
)
# adds the safety module to the policy learner as well
# @jalaj, we need to follow the practice below for safety module
self.policy_learner.safety_module = self.safety_module
self.replay_buffer: ReplayBuffer = (
PearlAgent.default_replay_buffer_type()
if replay_buffer is None
else replay_buffer
)
self.history_summarization_module: HistorySummarizationModule = (
PearlAgent.default_history_summarization_module_type()
if history_summarization_module is None
else history_summarization_module
)
self.policy_learner.set_history_summarization_module(
self.history_summarization_module
)
# set here so replay_buffer and policy_learner are in sync
self.replay_buffer.is_action_continuous = (
self.policy_learner.is_action_continuous
)
self.replay_buffer.device = self.device
# check that all components of the agent are compatible with each other
pearl_agent_compatibility_check(
self.policy_learner, self.safety_module, self.replay_buffer
)
self._subjective_state: Optional[SubjectiveState] = None
self._latest_action: Optional[Action] = None
self._action_space: Optional[ActionSpace] = None
self.policy_learner.to(self.device)
self.history_summarization_module.to(self.device)
def act(self, exploit: bool = False) -> Action:
# We need to adapt that to use Tensors, or instead revert equalling
# SubjectiveState to None.
assert self._action_space is not None
safe_action_space = self.safety_module.filter_action(
# pyre-fixme[6]: contextual bandit environments use subjective state None
self._subjective_state,
self._action_space,
)
# PolicyLearner requires all tensor inputs to be already on the correct device
# before being passed to it.
subjective_state_to_be_used = (
torch.as_tensor(self._subjective_state).to(self.device)
if self.policy_learner.requires_tensors # temporary fix before abstract interfaces
else self._subjective_state
)
# TODO: The following code is too specific to be at this high-level.
# This needs to be moved to a better place.
if (
isinstance(safe_action_space, DiscreteActionSpace)
and self.policy_learner.requires_tensors
):
assert isinstance(safe_action_space, DiscreteActionSpace)
safe_action_space.to(self.device)
action = self.policy_learner.act(
subjective_state_to_be_used, safe_action_space, exploit=exploit # pyre-fixme[6]
)
if isinstance(safe_action_space, DiscreteActionSpace):
self._latest_action = safe_action_space.actions_batch[int(action.item())]
else:
self._latest_action = action
return action
def observe(
self,
action_result: ActionResult,
) -> None:
current_history = self.history_summarization_module.get_history()
new_subjective_state = self._update_subjective_state(action_result.observation)
new_history = self.history_summarization_module.get_history()
# TODO: define each push with a uuid
# TODO: currently assumes the same action space across all steps
# need to modify ActionResults
assert self._latest_action is not None
assert self._action_space is not None
self.replay_buffer.push(
# pyre-fixme[6]: this can be removed when tabular Q learning test uses tensors
state=current_history,
action=self._latest_action,
reward=action_result.reward,
# pyre-fixme[6]: this can be removed when tabular Q learning test uses tensors
next_state=new_history,
curr_available_actions=self._action_space, # curr_available_actions
next_available_actions=self._action_space
if action_result.available_action_space is None
else action_result.available_action_space, # next_available_actions
done=action_result.done,
max_number_actions=self.policy_learner.action_representation_module.max_number_actions
if not self.policy_learner.is_action_continuous
else None, # max number of actions for discrete action space
cost=action_result.cost,
)
self._action_space = (
action_result.available_action_space
if action_result.available_action_space is not None
else self._action_space
)
self._subjective_state = new_subjective_state
def learn(self) -> Dict[str, Any]:
report = self.policy_learner.learn(self.replay_buffer)
self.safety_module.learn(self.replay_buffer, self.policy_learner)
if self.policy_learner.on_policy:
self.replay_buffer.clear()
return report
def learn_batch(self, batch: TransitionBatch) -> Dict[str, typing.Any]:
"""
This API is often used in offline learning
where users pass in a batch of data to train directly
"""
batch = self.policy_learner.preprocess_batch(batch)
policy_learner_loss = self.policy_learner.learn_batch(batch)
self.safety_module.learn_batch(batch)
return policy_learner_loss
def reset(
self, observation: Observation, available_action_space: ActionSpace
) -> None:
self.history_summarization_module.reset()
self.history_summarization_module.to(self.device)
self._latest_action = None
self._subjective_state = self._update_subjective_state(observation)
self._action_space = available_action_space
self.policy_learner.reset(available_action_space)
def _update_subjective_state(
self, observation: Observation
) -> Optional[SubjectiveState]:
if observation is None:
return None
latest_action_representation = None
if self._latest_action is not None:
latest_action_representation = (
self.policy_learner.action_representation_module(
torch.as_tensor(self._latest_action).unsqueeze(0).to(self.device)
)
)
observation_to_be_used = (
torch.as_tensor(observation).to(self.device)
if self.policy_learner.requires_tensors # temporary fix before abstract interfaces
else observation
)
return self.history_summarization_module.summarize_history(
observation_to_be_used, latest_action_representation
)
def __str__(self) -> str:
items = []
items.append(self.policy_learner)
if type(self.safety_module) is not PearlAgent.default_safety_module_type:
items.append(self.safety_module)
if (
type(self.history_summarization_module)
is not PearlAgent.default_history_summarization_module_type
):
items.append(self.history_summarization_module)
if type(self.replay_buffer) is not PearlAgent.default_replay_buffer_type:
items.append(self.replay_buffer)
return "PearlAgent" + (
" with " + ", ".join(str(item) for item in items) if items else ""
)
Classes
class PearlAgent (policy_learner: PolicyLearner, safety_module: Optional[SafetyModule] = None, replay_buffer: Optional[ReplayBuffer] = None, history_summarization_module: Optional[HistorySummarizationModule] = None, device_id: int = -1)
-
An Agent gathering the most common aspects of production-ready agents. It is meant as a catch-all agent whose functionality is defined by flags (and possibly factories down the line).
Initializes the PearlAgent.
Args
policy_learner : PolicyLearner
    An instance of PolicyLearner.
safety_module : SafetyModule, optional
    An instance of SafetyModule. Defaults to RiskNeutralSafetyModule for distributional policy learner types, and IdentitySafetyModule for other types.
history_summarization_module : HistorySummarizationModule, optional
    An instance of HistorySummarizationModule. Defaults to IdentityHistorySummarizationModule.
replay_buffer : ReplayBuffer, optional
    A replay buffer. Defaults to a single-transition replay buffer (note: this default is likely to change).
device_id : int, optional
    The device id on which the agent's modules are placed. Defaults to -1.
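Example: a minimal construction-and-training sketch. The specific policy learner, replay buffer, action representation module, environment wrapper, and their import paths follow the Pearl quickstart and are assumptions from outside this module; any PolicyLearner, ReplayBuffer, and environment that yields ActionResult objects fits the same loop.

from pearl.action_representation_modules.one_hot_action_representation_module import (
    OneHotActionTensorRepresentationModule,
)
from pearl.pearl_agent import PearlAgent
from pearl.policy_learners.sequential_decision_making.deep_q_learning import DeepQLearning
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment

env = GymEnvironment("CartPole-v1")
agent = PearlAgent(
    policy_learner=DeepQLearning(
        state_dim=env.observation_space.shape[0],
        action_space=env.action_space,
        hidden_dims=[64, 64],
        training_rounds=20,
        action_representation_module=OneHotActionTensorRepresentationModule(
            max_number_actions=env.action_space.n,
        ),
    ),
    replay_buffer=FIFOOffPolicyReplayBuffer(10_000),
)

# One episode of the standard act / observe / learn loop.
observation, action_space = env.reset()
agent.reset(observation, action_space)
done = False
while not done:
    action = agent.act(exploit=False)  # query the policy learner (may explore)
    action_result = env.step(action)   # the environment returns an ActionResult
    agent.observe(action_result)       # update history and push the transition
    agent.learn()                      # train policy learner and safety module
    done = action_result.done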
Ancestors
- Agent
- abc.ABC
Class variables
var default_history_summarization_module_type
-
A history summarization module that simply uses the original observations.
var default_replay_buffer_type
-
A replay buffer that stores a single transition at a time (the current default, which is likely to change).
var default_risk_sensitive_safety_module_type
-
A risk-neutral safety module that computes Q-values as the expectation of a Q-value distribution.
var default_safety_module_type
-
A safety module that does not restrict action spaces.
Methods
def act(self, exploit: bool = False) ‑> torch.Tensor
-
Selects an action for the current subjective state, chosen by the policy learner from the action space filtered by the safety module. With exploit=True, the policy learner returns its exploit action rather than an exploratory one.
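A small sketch of the exploit flag, assuming an agent that has already been reset on an environment:

train_action = agent.act(exploit=False)  # exploration is delegated to the policy learner
eval_action = agent.act(exploit=True)    # exploit (e.g. greedy) action, typically for evaluation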
def learn(self) ‑> Dict[str, Any]
-
Trains the policy learner and the safety module on the replay buffer, and clears the buffer afterwards if the policy learner is on-policy. Returns the report produced by the policy learner.
def learn_batch(self, batch: TransitionBatch) ‑> Dict[str, Any]
-
This API is often used in offline learning, where users pass in a batch of data to train on directly. The batch is preprocessed by the policy learner and then used to update both the policy learner and the safety module.
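A sketch of an offline training loop built on learn_batch(), assuming the logged transitions have already been loaded into a replay buffer (here called offline_replay_buffer, a hypothetical name) whose sample(batch_size) returns a TransitionBatch:

num_training_steps = 1000
batch_size = 128
for _ in range(num_training_steps):
    batch = offline_replay_buffer.sample(batch_size)  # TransitionBatch of logged transitions
    report = agent.learn_batch(batch)  # preprocesses the batch, then updates policy learner and safety module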
def observe(self, action_result: ActionResult) ‑> None
-
Records the result of the latest action: updates the subjective state through the history summarization module and pushes the resulting transition (state, action, reward, next state, available actions, done, cost) to the replay buffer.
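When the environment is not a Pearl environment, observe() can be driven by constructing ActionResult objects by hand. The sketch below assumes a hypothetical step function and that ActionResult accepts the fields observe() reads (observation, reward, done, cost, available_action_space); check pearl.api.action_result for the exact constructor signature.

from pearl.api.action_result import ActionResult

action = agent.act()
obs, reward, done = my_custom_env_step(action)  # hypothetical environment step
agent.observe(
    ActionResult(
        observation=obs,
        reward=reward,
        done=done,
        cost=None,                    # optional cost signal consumed by safety modules
        available_action_space=None,  # None keeps the agent's current action space
    )
)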
def reset(self, observation: object, available_action_space: ActionSpace) ‑> None
-
Resets the agent for a new episode: clears the history summarization module and the latest action, computes the initial subjective state from the observation, and resets the policy learner with the available action space.