Module pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation

Expand source code
import io
import logging
import os
import sys
from logging import Logger
from typing import List, Optional

import requests
import torch

try:
    from libfb.py.certpathpicker.cert_path_picker import get_client_credential_paths
except ImportError:
    print(
        "Cert path picker failed to find FB proxy certificates. Downloading data from "
        "a URL directly won't work inside Meta."
    )

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.replay_buffers.transition import TransitionBatch
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.utils.functional_utils.train_and_eval.online_learning import run_episode
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

FWDPROXY_PORT = 8082
FWDPROXY_HOSTNAME = "https://fwdproxy"
FWDPROXY_CORP_HOSTNAME = "https://fwdproxy-regional-corp.{0}.fbinfra.net"
EXTERNAL_ENDPOINT = "https://www.google.com"
CORP_ENDPOINT = "https://npm.thefacebook.com/"

FB_CA_BUNDLE = "/var/facebook/rootcanal/ca.pem"

logger: Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))


def is_file_readable(file_path: str) -> bool:
    return os.path.isfile(file_path) and os.access(file_path, os.R_OK)


def get_offline_data_in_buffer(
    is_action_continuous: bool,
    url: Optional[str] = None,
    data_path: Optional[str] = None,
    size: int = 1000000,
) -> ReplayBuffer:
    """
    Fetches offline data from a url and returns a replay buffer which can be sampled
    to train the offline agent. For this implementation we use FIFOOffPolicyReplayBuffer.

    - Assumes the offline data is an iterable consisting of transition tuples
        (observation, action, reward, next_observation, curr_available_actions,
        next_available_actions, action_space_done) as dictionaries.

    - Also assumes offline data is in a .pt file; reading from a
        csv file can also be added later.

    Args:
        is_action_continuous: whether the action space is continuous or discrete.
            for continuous actions spaces, we need to set this flag; see 'push' method
            in FIFOOffPolicyReplayBuffer class
        url: from where offline data needs to be fetched from
        data_path: local path to the offline data
        size: size of the replay buffer

    Returns:
        ReplayBuffer: a FIFOOffPolicyReplayBuffer containing offline data of transition tuples.
    """
    if url is not None:
        thrift_cert, thrift_key = get_client_credential_paths()
        if not is_file_readable(thrift_cert) or not is_file_readable(thrift_key):
            raise RuntimeError("Missing key TLS cert settings.")

        fwdproxy_url = f"{FWDPROXY_HOSTNAME}:{FWDPROXY_PORT}"
        proxies = {"http": fwdproxy_url, "https": fwdproxy_url}
        client_cert = (thrift_cert, thrift_key)

        offline_transitions_data = requests.get(
            url, proxies=proxies, verify=FB_CA_BUNDLE, cert=client_cert
        )
        stream = io.BytesIO(offline_transitions_data.content)  # implements seek()
        raw_transitions_buffer = torch.load(stream)
    else:
        raw_transitions_buffer = torch.load(data_path)  # pyre-ignore

    offline_data_replay_buffer = FIFOOffPolicyReplayBuffer(size)
    if is_action_continuous:
        offline_data_replay_buffer._is_action_continuous = True

    for transition in raw_transitions_buffer:
        if transition["curr_available_actions"].__class__.__name__ == "Discrete":
            transition["curr_available_actions"] = DiscreteActionSpace(
                actions=list(
                    torch.arange(transition["curr_available_actions"].n).view(-1, 1)
                )
            )
        if transition["next_available_actions"].__class__.__name__ == "Discrete":
            transition["next_available_actions"] = DiscreteActionSpace(
                actions=list(
                    torch.arange(transition["next_available_actions"].n).view(-1, 1)
                )
            )
        if transition["action_space"].__class__.__name__ == "Discrete":
            transition["action_space"] = DiscreteActionSpace(
                actions=list(torch.arange(transition["action_space"].n).view(-1, 1))
            )

        offline_data_replay_buffer.push(
            transition["observation"],
            transition["action"],
            transition["reward"],
            transition["next_observation"],
            transition["curr_available_actions"],
            transition["next_available_actions"],
            transition["done"],
            transition["action_space"].n,
        )

    return offline_data_replay_buffer


def offline_learning(
    offline_agent: PearlAgent,
    data_buffer: ReplayBuffer,
    training_epochs: int = 1000,
    seed: int = 100,
) -> None:
    """
    Trains the offline agent using transition tuples from offline data (provided in
    the data_buffer). Must provide a replay buffer with transition tuples - please
    use the method get_offline_data_in_buffer to create an offline data buffer.

    Args:
        offline agent: a conservative learning agent (CQL or IQL).
        data_buffer: a replay buffer to sample a batch of transition data.
        training_epochs: number of training epochs for offline learning.
    """
    set_seed(seed=seed)

    # move replay buffer to device of the offline agent
    data_buffer.device = offline_agent.device

    # training loop
    for i in range(training_epochs):
        batch = data_buffer.sample(offline_agent.policy_learner.batch_size)
        assert isinstance(batch, TransitionBatch)
        loss = offline_agent.learn_batch(batch=batch)
        if i % 500 == 0:
            print("training epoch", i, "training loss", loss)


def offline_evaluation(
    offline_agent: PearlAgent,
    env: Environment,
    number_of_episodes: int = 1000,
    seed: Optional[int] = None,
) -> List[float]:
    """
    Evaluates the performance of an offline trained agent.

    Args:
        agent: the offline trained agent.
        env: the environment to evaluate the agent in
        number_of_episodes: the number of episodes to evaluate for.
    Returns:
        returns_offline_agent: a list of returns for each evaluation episode.
    """

    # check: during offline evaluation, the agent should not learn or explore.
    learn = False
    exploit = True
    learn_after_episode = False

    returns_offline_agent = []
    for i in range(number_of_episodes):
        evaluation_seed = seed + i if seed is not None else None
        episode_info, total_steps = run_episode(
            agent=offline_agent,
            env=env,
            learn=learn,
            exploit=exploit,
            learn_after_episode=learn_after_episode,
            seed=evaluation_seed,
        )
        g = episode_info["return"]
        # print progress after every episode
        print(f"\repisode {i}, return={g}", end="")
        returns_offline_agent.append(g)

    return returns_offline_agent

Functions

def get_offline_data_in_buffer(is_action_continuous: bool, url: Optional[str] = None, data_path: Optional[str] = None, size: int = 1000000) -> ReplayBuffer

Fetches offline data from a URL (or a local path) and returns a replay buffer that can be sampled to train an offline agent. This implementation uses a FIFOOffPolicyReplayBuffer.

  • Assumes the offline data is an iterable of transition dictionaries with the keys (observation, action, reward, next_observation, curr_available_actions, next_available_actions, action_space, done).

  • Also assumes the offline data is stored in a .pt file; reading from a csv file could be added later.

Args

is_action_continuous
whether the action space is continuous or discrete. For continuous action spaces, this flag must be set; see the 'push' method of the FIFOOffPolicyReplayBuffer class.
url
URL from which the offline data is fetched. If None, data_path is used instead.
data_path
local path to the offline data.
size
capacity of the replay buffer.

Returns

ReplayBuffer
a FIFOOffPolicyReplayBuffer containing the offline transition tuples.
Expand source code
def get_offline_data_in_buffer(
    is_action_continuous: bool,
    url: Optional[str] = None,
    data_path: Optional[str] = None,
    size: int = 1000000,
) -> ReplayBuffer:
    """
    Fetches offline data from a url and returns a replay buffer which can be sampled
    to train the offline agent. For this implementation we use FIFOOffPolicyReplayBuffer.

    - Assumes the offline data is an iterable consisting of transition tuples
        (observation, action, reward, next_observation, curr_available_actions,
        next_available_actions, action_space_done) as dictionaries.

    - Also assumes offline data is in a .pt file; reading from a
        csv file can also be added later.

    Args:
        is_action_continuous: whether the action space is continuous or discrete.
            for continuous actions spaces, we need to set this flag; see 'push' method
            in FIFOOffPolicyReplayBuffer class
        url: from where offline data needs to be fetched from
        data_path: local path to the offline data
        size: size of the replay buffer

    Returns:
        ReplayBuffer: a FIFOOffPolicyReplayBuffer containing offline data of transition tuples.
    """
    if url is not None:
        thrift_cert, thrift_key = get_client_credential_paths()
        if not is_file_readable(thrift_cert) or not is_file_readable(thrift_key):
            raise RuntimeError("Missing key TLS cert settings.")

        fwdproxy_url = f"{FWDPROXY_HOSTNAME}:{FWDPROXY_PORT}"
        proxies = {"http": fwdproxy_url, "https": fwdproxy_url}
        client_cert = (thrift_cert, thrift_key)

        offline_transitions_data = requests.get(
            url, proxies=proxies, verify=FB_CA_BUNDLE, cert=client_cert
        )
        stream = io.BytesIO(offline_transitions_data.content)  # implements seek()
        raw_transitions_buffer = torch.load(stream)
    else:
        raw_transitions_buffer = torch.load(data_path)  # pyre-ignore

    offline_data_replay_buffer = FIFOOffPolicyReplayBuffer(size)
    if is_action_continuous:
        offline_data_replay_buffer._is_action_continuous = True

    for transition in raw_transitions_buffer:
        if transition["curr_available_actions"].__class__.__name__ == "Discrete":
            transition["curr_available_actions"] = DiscreteActionSpace(
                actions=list(
                    torch.arange(transition["curr_available_actions"].n).view(-1, 1)
                )
            )
        if transition["next_available_actions"].__class__.__name__ == "Discrete":
            transition["next_available_actions"] = DiscreteActionSpace(
                actions=list(
                    torch.arange(transition["next_available_actions"].n).view(-1, 1)
                )
            )
        if transition["action_space"].__class__.__name__ == "Discrete":
            transition["action_space"] = DiscreteActionSpace(
                actions=list(torch.arange(transition["action_space"].n).view(-1, 1))
            )

        offline_data_replay_buffer.push(
            transition["observation"],
            transition["action"],
            transition["reward"],
            transition["next_observation"],
            transition["curr_available_actions"],
            transition["next_available_actions"],
            transition["done"],
            transition["action_space"].n,
        )

    return offline_data_replay_buffer
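
For illustration, a minimal usage sketch is shown below. It assumes the offline data lives in a local file named offline_data.pt (a hypothetical path) and that the action space is discrete; only the documented arguments of this function are exercised.

from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import (
    get_offline_data_in_buffer,
)

# Load offline transitions from a local .pt file into a FIFO off-policy replay buffer.
# "offline_data.pt" is a hypothetical file containing transition dictionaries with the
# keys documented above; is_action_continuous=False assumes a discrete action space.
offline_buffer = get_offline_data_in_buffer(
    is_action_continuous=False,
    data_path="offline_data.pt",
    size=1_000_000,
)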
def is_file_readable(file_path: str) -> bool
Expand source code
def is_file_readable(file_path: str) -> bool:
    return os.path.isfile(file_path) and os.access(file_path, os.R_OK)
def offline_evaluation(offline_agent: PearlAgent, env: Environment, number_of_episodes: int = 1000, seed: Optional[int] = None) -> List[float]

Evaluates the performance of an offline-trained agent.

Args

offline_agent
the offline-trained agent.
env
the environment to evaluate the agent in.
number_of_episodes
the number of episodes to evaluate for.
seed
if provided, episode i is run with seed + i; otherwise no seed is set.

Returns

returns_offline_agent
a list of returns, one per evaluation episode.
Expand source code
def offline_evaluation(
    offline_agent: PearlAgent,
    env: Environment,
    number_of_episodes: int = 1000,
    seed: Optional[int] = None,
) -> List[float]:
    """
    Evaluates the performance of an offline trained agent.

    Args:
        agent: the offline trained agent.
        env: the environment to evaluate the agent in
        number_of_episodes: the number of episodes to evaluate for.
    Returns:
        returns_offline_agent: a list of returns for each evaluation episode.
    """

    # check: during offline evaluation, the agent should not learn or explore.
    learn = False
    exploit = True
    learn_after_episode = False

    returns_offline_agent = []
    for i in range(number_of_episodes):
        evaluation_seed = seed + i if seed is not None else None
        episode_info, total_steps = run_episode(
            agent=offline_agent,
            env=env,
            learn=learn,
            exploit=exploit,
            learn_after_episode=learn_after_episode,
            seed=evaluation_seed,
        )
        g = episode_info["return"]
        # print progress after every episode
        print(f"\repisode {i}, return={g}", end="")
        returns_offline_agent.append(g)

    return returns_offline_agent
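
As a hedged sketch of how this function might be called: the snippet below assumes that trained_agent is a PearlAgent produced by offline_learning and that env is an Environment instance created elsewhere; both names are assumptions, not definitions from this module.

# `trained_agent` and `env` are assumed to exist: a PearlAgent produced by
# offline_learning and an Environment instance created elsewhere.
returns = offline_evaluation(
    offline_agent=trained_agent,
    env=env,
    number_of_episodes=100,
    seed=0,  # episode i runs with seed 0 + i, making the evaluation reproducible
)
mean_return = sum(returns) / len(returns)
print(f"\nmean return over {len(returns)} episodes: {mean_return:.2f}")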
def offline_learning(offline_agent: PearlAgent, data_buffer: ReplayBuffer, training_epochs: int = 1000, seed: int = 100) -> None

Trains an offline agent using transition tuples sampled from the offline data in data_buffer. The replay buffer must contain transition tuples; use get_offline_data_in_buffer to create such a buffer from offline data.

Args

offline_agent
a conservative learning agent (e.g. CQL or IQL).
data_buffer
a replay buffer from which to sample batches of transition data.
training_epochs
number of training epochs (gradient steps) for offline learning.
seed
random seed for reproducibility.
Expand source code
def offline_learning(
    offline_agent: PearlAgent,
    data_buffer: ReplayBuffer,
    training_epochs: int = 1000,
    seed: int = 100,
) -> None:
    """
    Trains the offline agent using transition tuples from offline data (provided in
    the data_buffer). Must provide a replay buffer with transition tuples - please
    use the method get_offline_data_in_buffer to create an offline data buffer.

    Args:
        offline agent: a conservative learning agent (CQL or IQL).
        data_buffer: a replay buffer to sample a batch of transition data.
        training_epochs: number of training epochs for offline learning.
    """
    set_seed(seed=seed)

    # move replay buffer to device of the offline agent
    data_buffer.device = offline_agent.device

    # training loop
    for i in range(training_epochs):
        batch = data_buffer.sample(offline_agent.policy_learner.batch_size)
        assert isinstance(batch, TransitionBatch)
        loss = offline_agent.learn_batch(batch=batch)
        if i % 500 == 0:
            print("training epoch", i, "training loss", loss)