Module pearl.utils.functional_utils.experimentation.create_offline_data
Expand source code
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import pickle
from collections import deque
from typing import List, Optional
import torch
from pearl.api.environment import Environment
from pearl.api.reward import Value
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.train_and_eval.online_learning import run_episode
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
def create_offline_data(
    agent: PearlAgent,
    env: Environment,
    save_path: str,
    file_name: str,
    max_len_offline_data: int = 50000,
    learn: bool = True,
    learn_after_episode: bool = True,
    evaluation_episodes: int = 100,
    seed: Optional[int] = None,
) -> List[Value]:
    """
    This function creates offline data by interacting with a given environment using a specified
    Pearl Agent. This is mostly for illustration with standard benchmark environments only.
    For most practical use cases, offline data collection will use custom pipelines.

    Note: the agent always acts with exploit=False here, since exploration is desired during
    data collection.

    Args:
        agent: a Pearl agent with policy learner, exploration module and replay buffer specified
            (e.g. a DQN agent).
        env: an environment to collect data from (e.g. GymEnvironment).
        save_path: directory prefix where the offline data and return files are written.
        file_name: name of the file in which the raw transition tuples are saved.
        max_len_offline_data: maximum number of transitions to collect.
        learn: whether the data collection agent learns during collection
            (depends on the policy learner used by the agent).
        learn_after_episode: if learning, whether to learn after each episode instead of after
            each environment step.
        evaluation_episodes: number of evaluation episodes to run with the data collection agent
            after data collection is complete.
        seed: optional seed for environment resets during data collection.
    """
    # Note: much of this function overlaps with run_episode, but a separate implementation
    # is cleaner for data collection.
    print(f"collecting data from env: {env} using agent: {agent}")
    epi_returns = []
    epi = 0
    raw_transitions_buffer = deque([], maxlen=max_len_offline_data)
    while len(raw_transitions_buffer) < max_len_offline_data:
        g = 0
        observation, action_space = env.reset(seed=seed)
        agent.reset(observation, action_space)
        done = False
        while not done:
            action = agent.act(
                exploit=False
            )  # exploit is explicitly set to False as we want exploration during data collection.
            action_result = env.step(action)
            g += action_result.reward
            agent.observe(action_result)
            transition_tuple = {
                "observation": observation,
                "action": action,
                "reward": action_result.reward,
                "next_observation": action_result.observation,
                "curr_available_actions": env.action_space,
                "next_available_actions": env.action_space,
                "done": action_result.done,
                "max_number_actions": env.action_space.n
                if isinstance(env.action_space, DiscreteActionSpace)
                else None,
            }
            observation = action_result.observation
            raw_transitions_buffer.append(transition_tuple)
            if learn and not learn_after_episode:
                agent.learn()
            done = action_result.done
        if learn and learn_after_episode:
            agent.learn()
        epi_returns.append(g)
        print(f"\rEpisode {epi}, return={g}", end="")
        epi += 1
    # save offline transition tuples in a .pt file
    torch.save(raw_transitions_buffer, save_path + file_name)
    # save training returns of the data collection agent
    with open(
        save_path
        + "training_returns_data_collection_agent_"
        + str(max_len_offline_data)
        + ".pickle",
        "wb",
    ) as handle:
        # @lint-ignore PYTHONPICKLEISBAD
        pickle.dump(epi_returns, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # evaluation results of the data collection agent
    print(" ")
    print(
        "data collection complete; starting evaluation runs for data collection agent"
    )
    evaluation_returns = []
    for i in range(evaluation_episodes):
        # data collection and evaluation seeds should be different
        evaluation_seed = seed + i if seed is not None else seed
        episode_info, _ = run_episode(
            agent=agent,
            env=env,
            learn=False,
            exploit=True,
            learn_after_episode=False,
            seed=evaluation_seed,
        )
        g = episode_info["return"]
        print(f"\repisode {i}, return={g}", end="")
        evaluation_returns.append(g)
    with open(
        save_path
        + "evaluation_returns_data_collection_agent_"
        + str(max_len_offline_data)
        + ".pickle",
        "wb",
    ) as handle:
        # @lint-ignore PYTHONPICKLEISBAD
        pickle.dump(evaluation_returns, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return epi_returns  # for plotting returns of the policy used to collect offline data
# getting returns of the data collection agent, either from file or by stitching trajectories
# in the training data
def get_data_collection_agent_returns(
    data_path: str,
    returns_file_path: Optional[str] = None,
) -> List[Value]:
    """
    This function returns episode returns of a Pearl Agent used for offline data collection.
    The returns file can be directly provided, or episode returns can be stitched together from
    trajectories in the offline data. This function is used to compute normalized scores for
    offline RL benchmarks.

    Args:
        data_path: path to the file where the offline data is stored.
        returns_file_path: path to the file containing returns of the data collection agent.
    """
    print("getting returns of the data collection agent")
    if returns_file_path is None:
        print(
            f"using offline training data in {data_path} to stitch trajectories and compute returns"
        )
        with open(data_path, "rb") as file:
            data = torch.load(file, map_location=torch.device("cpu"))
        data_collection_agent_returns = []
        g = 0
        for transition in list(data):
            # accumulate every reward, including the terminal transition's reward,
            # before closing out the episode return
            g += transition["reward"]
            if transition["done"]:
                data_collection_agent_returns.append(g)
                g = 0
    else:
        print(f"loading returns from file {returns_file_path}")
        with open(returns_file_path, "rb") as file:
            # @lint-ignore PYTHONPICKLEISBAD
            data_collection_agent_returns = pickle.load(file)
    return data_collection_agent_returns
Functions
def create_offline_data(agent: PearlAgent, env: Environment, save_path: str, file_name: str, max_len_offline_data: int = 50000, learn: bool = True, learn_after_episode: bool = True, evaluation_episodes: int = 100, seed: Optional[int] = None) ‑> List[object]
-
This function creates offline data by interacting with a given environment using a specified Pearl Agent. This is mostly for illustration with standard benchmark environments only. For most practical use cases, offline data collection will use custom pipelines. The agent always acts with exploit=False here, since exploration is desired during data collection (see the usage sketch after the argument list).
Args
agent
- a Pearl agent with policy learner, exploration module and replay buffer specified (e.g. a DQN agent).
env
- an environment to collect data from (e.g. GymEnvironment).
save_path
- directory prefix where the offline data and return files are written.
file_name
- name of the file in which the raw transition tuples are saved.
max_len_offline_data
- maximum number of transitions to collect.
learn
- whether the data collection agent learns during collection (depends on the policy learner used by the agent).
learn_after_episode
- if learning, whether to learn after each episode instead of after each environment step.
evaluation_episodes
- number of evaluation episodes to run with the data collection agent after data collection is complete.
seed
- optional seed for environment resets during data collection.
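Below is a minimal usage sketch for create_offline_data. It is illustrative only: the helper name collect_offline_dataset, the save_path and file_name values, and the choice of 10,000 transitions are assumptions, not part of this module; any PearlAgent and Environment pair with the interfaces above will work. Note that save_path is concatenated directly with the file names, so it should end with a path separator and the directory should already exist.

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.experimentation.create_offline_data import (
    create_offline_data,
)

def collect_offline_dataset(agent: PearlAgent, env: Environment) -> None:
    # Hypothetical wrapper (not part of Pearl): collect up to 10,000 exploratory
    # transitions and save them under ./offline_data/.
    training_returns = create_offline_data(
        agent=agent,
        env=env,
        save_path="./offline_data/",      # directory prefix; trailing separator required
        file_name="raw_transitions.pt",   # transitions are saved as a torch-serialized deque
        max_len_offline_data=10_000,
        learn=True,                       # let the data collection agent keep learning
        learn_after_episode=True,
        evaluation_episodes=20,
        seed=0,
    )
    print(f"collected data over {len(training_returns)} episodes")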
def get_data_collection_agent_returns(data_path: str, returns_file_path: Optional[str] = None) ‑> List[object]
-
This function returns episode returns of a Pearl Agent used for offline data collection. The returns file can be directly provided, or episode returns can be stitched together from trajectories in the offline data. This function is used to compute normalized scores for offline RL benchmarks (see the usage sketch after the argument list).
Args
data_path
- path to the file where the offline data is stored.
returns_file_path
- path to the file containing returns of the data collection agent.
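A short usage sketch for get_data_collection_agent_returns follows. The file names mirror those written by the collection sketch above (save_path "./offline_data/", max_len_offline_data=10_000) and are assumptions; substitute your own paths.

from pearl.utils.functional_utils.experimentation.create_offline_data import (
    get_data_collection_agent_returns,
)

# Stitch episode returns directly from the saved raw transitions file.
stitched_returns = get_data_collection_agent_returns(
    data_path="./offline_data/raw_transitions.pt"
)

# Or load the pickled evaluation returns written by create_offline_data; data_path is
# ignored when returns_file_path is provided.
evaluation_returns = get_data_collection_agent_returns(
    data_path="./offline_data/raw_transitions.pt",
    returns_file_path="./offline_data/evaluation_returns_data_collection_agent_10000.pickle",
)

print(f"stitched {len(stitched_returns)} episode returns from the raw transitions")
print(f"loaded {len(evaluation_returns)} evaluation returns from the pickle file")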