Module pearl.utils.functional_utils.train_and_eval.online_learning

Expand source code
import logging
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from pearl.api.environment import Environment
from pearl.api.reward import Value
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.experimentation.plots import fontsize_for

MA_WINDOW_SIZE = 10


def latest_moving_average(data: List[Value]) -> float:
    return (
        sum(data[-MA_WINDOW_SIZE:]) * 1.0 / MA_WINDOW_SIZE  # pyre-ignore
        if len(data) >= MA_WINDOW_SIZE
        else sum(data) * 1.0 / len(data)  # pyre-ignore
    )


def online_learning_to_png_graph(
    agent: PearlAgent,
    env: Environment,
    filename: str = "returns.png",
    number_of_episodes: int = 1000,
    learn_after_episode: bool = False,
) -> None:
    """
    Runs online learning and generates a PNG graph of the returns.

    Args:
        agent (PearlAgent): the agent.
        env (Environment): the environment.
        filename (str, optional): the filename to save to. Defaults to "returns.png".
        number_of_episodes (int, optional): the number of episodes to run. Defaults to 1000.
        learn_after_episode (bool, optional): if True, learn only at the end of each
            episode; otherwise learn after every step. Defaults to False.
    """

    info = online_learning(
        agent=agent,
        env=env,
        number_of_episodes=number_of_episodes,
        learn_after_episode=learn_after_episode,
    )
    assert "return" in info

    if filename is not None:
        title = f"{str(agent)} on {str(env)}"
        if len(title) > 125:
            logging.warning(
                f"Figure title is very long, with {len(title)} characters: {title}"
            )
        plt.plot(info["return"])
        plt.title(title, fontsize=fontsize_for(title))
        plt.xlabel("Episode")
        plt.ylabel("Return")
        plt.savefig(filename)
        plt.close()


def online_learning(
    agent: PearlAgent,
    env: Environment,
    number_of_episodes: Optional[int] = None,
    number_of_steps: Optional[int] = None,
    learn_after_episode: bool = False,
    print_every_x_episodes: Optional[int] = None,
    print_every_x_steps: Optional[int] = None,
    seed: Optional[int] = None,
    # if number_of_episodes is used, report every record_period episodes
    # if number_of_steps is used, report every record_period steps
    # episodic stats collected within the period are averaged and then reported
    record_period: int = 1,
) -> Dict[str, Any]:
    """
    Performs online learning for a number of episodes.

    Args:
        agent (PearlAgent): the agent.
        env (Environment): the environmnent.
        number_of_episodes (int, optional): the number of episodes to run. Defaults to 1000.
        learn_after_episode (bool, optional): asks the agent to only learn after every episode.
        Defaults to False.
    """
    assert (number_of_episodes is None and number_of_steps is not None) or (
        number_of_episodes is not None and number_of_steps is None
    )
    total_steps = 0
    total_episodes = 0
    info = {}
    info_period = {}
    while True:
        if number_of_episodes is not None and total_episodes >= number_of_episodes:
            break
        if number_of_steps is not None and total_steps >= number_of_steps:
            break
        old_total_steps = total_steps
        episode_info, episode_total_steps = run_episode(
            agent,
            env,
            learn=True,
            exploit=False,
            learn_after_episode=learn_after_episode,
            total_steps=old_total_steps,
            seed=seed,
        )
        if number_of_steps is not None and episode_total_steps > record_period:
            print(
                f"An episode is longer than the record_period: episode length "
                f"{episode_total_steps}, record_period {record_period}. "
                "Try using a larger record_period."
            )
            exit(1)
        total_steps += episode_total_steps
        total_episodes += 1
        if (
            print_every_x_steps is not None
            and old_total_steps // print_every_x_steps
            < total_steps // print_every_x_steps
        ) or (
            print_every_x_episodes is not None
            and total_episodes % print_every_x_episodes == 0
        ):
            print(
                f"episode {total_episodes}, step {total_steps}, agent={agent}, env={env}",
            )
            for key in episode_info:
                print(f"{key}: {episode_info[key]}")
        for key in episode_info:
            info_period.setdefault(key, []).append(episode_info[key])
        if number_of_episodes is not None and (
            total_episodes % record_period == 0
        ):  # record average info value every record_period episodes
            for key in info_period:
                info.setdefault(key, []).append(np.mean(info_period[key]))
            info_period = {}
        if number_of_steps is not None and (
            old_total_steps // record_period < total_steps // record_period
        ):  # record average info value every record_period steps
            for key in info_period:
                info.setdefault(key, []).append(np.mean(info_period[key]))
            info_period = {}
    return info


def target_return_is_reached(
    target_return: Value,
    max_episodes: int,
    agent: PearlAgent,
    env: Environment,
    learn: bool,
    learn_after_episode: bool,
    exploit: bool,
    required_target_returns_in_a_row: int = 1,
    check_moving_average: bool = False,
) -> bool:
    """
    Learns until obtaining target return (a certain number of times in a row, default 1)
    or max_episodes are completed.
    Args
        target_return (Value): the targeted return.
        max_episodes (int): the maximum number of episodes to run.
        agent (Agent): the agent.
        env (Environment): the environment.
        learn (bool): whether to learn.
        learn_after_episode (bool): whether to learn after every episode.
        exploit (bool): whether to exploit.
        required_target_returns_in_a_row (int, optional): how many times we must hit the target
        to succeed.
        check_moving_average: if this is enabled, we check the if latest moving average value
                              reaches goal
    Returns
        bool: whether target_return has been obtained required_target_returns_in_a_row times
              in a row.
    """
    target_returns_in_a_row = 0
    returns = []
    total_steps = 0
    for i in range(max_episodes):
        if i % 10 == 0 and i != 0:
            print(f"episode {i} return:", returns[-1])
        episode_info, episode_total_steps = run_episode(
            agent=agent,
            env=env,
            learn=learn,
            learn_after_episode=learn_after_episode,
            exploit=exploit,
        )
        total_steps += episode_total_steps
        returns.append(episode_info["return"])
        value = (
            episode_info["return"]
            if not check_moving_average
            else latest_moving_average(returns)
        )
        if value >= target_return:
            target_returns_in_a_row += 1
            if target_returns_in_a_row >= required_target_returns_in_a_row:
                return True
        else:
            target_returns_in_a_row = 0
    return False


def run_episode(
    agent: PearlAgent,
    env: Environment,
    learn: bool = False,
    exploit: bool = True,
    learn_after_episode: bool = False,
    total_steps: int = 0,
    seed: Optional[int] = None,
) -> Tuple[Dict[str, Any], int]:
    """
    Runs one episode and returns an info dict and number of steps taken.

    Args:
        agent (Agent): the agent.
        env (Environment): the environment.
        learn (bool, optional): Runs `agent.learn()` after every step. Defaults to False.
        exploit (bool, optional): asks the agent to only exploit. Defaults to False.
        learn_after_episode (bool, optional): asks the agent to only learn at
                                              the end of the episode. Defaults to False.

    Returns:
        Tuple[Dict[str, Any], int]: the return of the episode and the number of steps taken.
    """
    if seed is None:
        observation, action_space = env.reset(seed=seed)
    else:
        # each episode has a different seed
        observation, action_space = env.reset(seed=seed + total_steps)
    agent.reset(observation, action_space)
    cum_reward = 0
    cum_cost = 0
    done = False
    episode_steps = 0
    num_risky_sa = 0
    while not done:
        action = agent.act(exploit=exploit)
        action = (
            action.cpu() if isinstance(action, torch.Tensor) else action
        )  # action can be int sometimes
        action_result = env.step(action)
        cum_reward += action_result.reward
        if (
            num_risky_sa is not None
            and action_result.info is not None
            and "risky_sa" in action_result.info
        ):
            num_risky_sa += action_result.info["risky_sa"]
        else:
            num_risky_sa = None
        if cum_cost is not None and action_result.cost is not None:
            cum_cost += action_result.cost
        else:
            cum_cost = None
        agent.observe(action_result)
        if learn and not learn_after_episode:
            agent.learn()
        done = action_result.done
        episode_steps += 1

    if learn and learn_after_episode:
        agent.learn()

    info = {"return": cum_reward}
    if num_risky_sa is not None:
        info.update({"risky_sa_ratio": num_risky_sa / episode_steps})
    if cum_cost is not None:
        info.update({"return_cost": cum_cost})

    return info, episode_steps

Functions

def latest_moving_average(data: List[object]) ‑> float
Expand source code
def latest_moving_average(data: List[Value]) -> float:
    return (
        sum(data[-MA_WINDOW_SIZE:]) * 1.0 / MA_WINDOW_SIZE  # pyre-ignore
        if len(data) >= MA_WINDOW_SIZE
        else sum(data) * 1.0 / len(data)  # pyre-ignore
    )
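
For illustration, a minimal usage sketch of the helper (the return values below are made-up numbers, not produced by any agent):

from pearl.utils.functional_utils.train_and_eval.online_learning import (
    latest_moving_average,
)

# Fewer than MA_WINDOW_SIZE (10) entries: the whole list is averaged -> 3.0
print(latest_moving_average([1.0, 2.0, 3.0, 4.0, 5.0]))

# At least 10 entries: only the last 10 are averaged -> 15.5
print(latest_moving_average([float(x) for x in range(1, 21)]))
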
def online_learning(agent: PearlAgent, env: Environment, number_of_episodes: Optional[int] = None, number_of_steps: Optional[int] = None, learn_after_episode: bool = False, print_every_x_episodes: Optional[int] = None, print_every_x_steps: Optional[int] = None, seed: Optional[int] = None, record_period: int = 1) ‑> Dict[str, Any]

Performs online learning for a number of episodes or a number of steps.

Args

agent : PearlAgent
the agent.
env : Environment
the environment.
number_of_episodes : int, optional
the number of episodes to run. Exactly one of number_of_episodes and number_of_steps must be provided.
number_of_steps : int, optional
the number of environment steps to run.
learn_after_episode : bool, optional
if True, the agent learns only at the end of each episode instead of after every step. Defaults to False.
print_every_x_episodes : int, optional
if provided, print progress every x episodes.
print_every_x_steps : int, optional
if provided, print progress every x steps.
seed : int, optional
environment reset seed, offset by the step count so that each episode gets a different seed.
record_period : int
number of episodes (or steps, if number_of_steps is used) over which episodic stats are averaged before being recorded. Defaults to 1.

Expand source code
def online_learning(
    agent: PearlAgent,
    env: Environment,
    number_of_episodes: Optional[int] = None,
    number_of_steps: Optional[int] = None,
    learn_after_episode: bool = False,
    print_every_x_episodes: Optional[int] = None,
    print_every_x_steps: Optional[int] = None,
    seed: Optional[int] = None,
    # if number_of_episodes is used, report every record_period episodes
    # if number_of_steps is used, report every record_period steps
    # episodic stats collected within the period are averaged and then reported
    record_period: int = 1,
) -> Dict[str, Any]:
    """
    Performs online learning for a number of episodes.

    Args:
        agent (PearlAgent): the agent.
        env (Environment): the environmnent.
        number_of_episodes (int, optional): the number of episodes to run. Defaults to 1000.
        learn_after_episode (bool, optional): asks the agent to only learn after every episode.
        Defaults to False.
    """
    assert (number_of_episodes is None and number_of_steps is not None) or (
        number_of_episodes is not None and number_of_steps is None
    )
    total_steps = 0
    total_episodes = 0
    info = {}
    info_period = {}
    while True:
        if number_of_episodes is not None and total_episodes >= number_of_episodes:
            break
        if number_of_steps is not None and total_steps >= number_of_steps:
            break
        old_total_steps = total_steps
        episode_info, episode_total_steps = run_episode(
            agent,
            env,
            learn=True,
            exploit=False,
            learn_after_episode=learn_after_episode,
            total_steps=old_total_steps,
            seed=seed,
        )
        if number_of_steps is not None and episode_total_steps > record_period:
            print(
                f"An episode is longer than the record_period: episode length "
                f"{episode_total_steps}, record_period {record_period}. "
                "Try using a larger record_period."
            )
            exit(1)
        total_steps += episode_total_steps
        total_episodes += 1
        if (
            print_every_x_steps is not None
            and old_total_steps // print_every_x_steps
            < total_steps // print_every_x_steps
        ) or (
            print_every_x_episodes is not None
            and total_episodes % print_every_x_episodes == 0
        ):
            print(
                f"episode {total_episodes}, step {total_steps}, agent={agent}, env={env}",
            )
            for key in episode_info:
                print(f"{key}: {episode_info[key]}")
        for key in episode_info:
            info_period.setdefault(key, []).append(episode_info[key])
        if number_of_episodes is not None and (
            total_episodes % record_period == 0
        ):  # record average info value every record_period episodes
            for key in info_period:
                info.setdefault(key, []).append(np.mean(info_period[key]))
            info_period = {}
        if number_of_steps is not None and (
            old_total_steps // record_period < total_steps // record_period
        ):  # record average info value every record_period steps
            for key in info_period:
                info.setdefault(key, []).append(np.mean(info_period[key]))
            info_period = {}
    return info
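
A usage sketch of the two mutually exclusive modes follows. It assumes `agent` and `env` are an already-constructed PearlAgent and Environment (their construction is outside the scope of this module), and the helper names, episode/step counts, and record_period values are arbitrary illustrative choices.

from typing import Any, Dict

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning


def train_by_episodes(agent: PearlAgent, env: Environment) -> Dict[str, Any]:
    # Episode-based mode: run 500 episodes, learning only at the end of each episode,
    # and record the average of the episodic stats (e.g. "return") over every 10
    # episodes, so info["return"] ends up with 500 / 10 = 50 recorded values.
    return online_learning(
        agent=agent,
        env=env,
        number_of_episodes=500,
        learn_after_episode=True,
        print_every_x_episodes=50,
        record_period=10,
    )


def train_by_steps(agent: PearlAgent, env: Environment) -> Dict[str, Any]:
    # Step-based mode: run until at least 10_000 environment steps have elapsed;
    # record_period is then counted in steps and must be at least as large as the
    # longest episode, otherwise the function aborts with an error.
    return online_learning(
        agent=agent,
        env=env,
        number_of_steps=10_000,
        record_period=1_000,
    )
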
def online_learning_to_png_graph(agent: PearlAgent, env: Environment, filename: str = 'returns.png', number_of_episodes: int = 1000, learn_after_episode: bool = False) ‑> None

Runs online learning and generates a PNG graph of the returns.

Args

agent : PearlAgent
the agent.
env : Environment
the environment.
filename : str, optional
the filename to save to. Defaults to "returns.png".
number_of_episodes : int, optional
the number of episodes to run. Defaults to 1000.
learn_after_episode : bool, optional
if True, learn only at the end of each episode; otherwise learn after every step. Defaults to False.
Expand source code
def online_learning_to_png_graph(
    agent: PearlAgent,
    env: Environment,
    filename: str = "returns.png",
    number_of_episodes: int = 1000,
    learn_after_episode: bool = False,
) -> None:
    """
    Runs online learning and generates a PNG graph of the returns.

    Args:
        agent (PearlAgent): the agent.
        env (Environment): the environment.
        filename (str, optional): the filename to save to. Defaults to "returns.png".
        number_of_episodes (int, optional): the number of episodes to run. Defaults to 1000.
        learn_after_episode (bool, optional): if True, learn only at the end of each
            episode; otherwise learn after every step. Defaults to False.
    """

    info = online_learning(
        agent=agent,
        env=env,
        number_of_episodes=number_of_episodes,
        learn_after_episode=learn_after_episode,
    )
    assert "return" in info

    if filename is not None:
        title = f"{str(agent)} on {str(env)}"
        if len(title) > 125:
            logging.warning(
                f"Figure title is very long, with {len(title)} characters: {title}"
            )
        plt.plot(info["return"])
        plt.title(title, fontsize=fontsize_for(title))
        plt.xlabel("Episode")
        plt.ylabel("Return")
        plt.savefig(filename)
        plt.close()
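
A usage sketch, again assuming an already-constructed `agent` and `env`; the filename, episode count, and helper name are arbitrary illustrative choices.

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.train_and_eval.online_learning import (
    online_learning_to_png_graph,
)


def plot_training_returns(agent: PearlAgent, env: Environment) -> None:
    # Trains for 1000 episodes (learning at the end of each episode) and writes a
    # per-episode return curve to training_returns.png in the current directory.
    online_learning_to_png_graph(
        agent=agent,
        env=env,
        filename="training_returns.png",
        number_of_episodes=1000,
        learn_after_episode=True,
    )
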
def run_episode(agent: PearlAgent, env: Environment, learn: bool = False, exploit: bool = True, learn_after_episode: bool = False, total_steps: int = 0, seed: Optional[int] = None) ‑> Tuple[Dict[str, Any], int]

Runs one episode and returns an info dict and number of steps taken.

Args

agent : PearlAgent
the agent.
env : Environment
the environment.
learn : bool, optional
runs agent.learn() after every step. Defaults to False.
exploit : bool, optional
asks the agent to only exploit. Defaults to True.
learn_after_episode : bool, optional
asks the agent to only learn at the end of the episode. Defaults to False.
total_steps : int, optional
the number of steps taken so far; used to offset the seed so that each episode gets a different seed.
seed : int, optional
environment reset seed; if provided, the episode is reset with seed + total_steps.

Returns

Tuple[Dict[str, Any], int]
an info dict with episode statistics (e.g. "return") and the number of steps taken.
Expand source code
def run_episode(
    agent: PearlAgent,
    env: Environment,
    learn: bool = False,
    exploit: bool = True,
    learn_after_episode: bool = False,
    total_steps: int = 0,
    seed: Optional[int] = None,
) -> Tuple[Dict[str, Any], int]:
    """
    Runs one episode and returns an info dict and number of steps taken.

    Args:
        agent (Agent): the agent.
        env (Environment): the environment.
        learn (bool, optional): Runs `agent.learn()` after every step. Defaults to False.
        exploit (bool, optional): asks the agent to only exploit. Defaults to False.
        learn_after_episode (bool, optional): asks the agent to only learn at
                                              the end of the episode. Defaults to False.

    Returns:
        Tuple[Dict[str, Any], int]: the return of the episode and the number of steps taken.
    """
    if seed is None:
        observation, action_space = env.reset(seed=seed)
    else:
        # each episode has a different seed
        observation, action_space = env.reset(seed=seed + total_steps)
    agent.reset(observation, action_space)
    cum_reward = 0
    cum_cost = 0
    done = False
    episode_steps = 0
    num_risky_sa = 0
    while not done:
        action = agent.act(exploit=exploit)
        action = (
            action.cpu() if isinstance(action, torch.Tensor) else action
        )  # action can be int sometimes
        action_result = env.step(action)
        cum_reward += action_result.reward
        if (
            num_risky_sa is not None
            and action_result.info is not None
            and "risky_sa" in action_result.info
        ):
            num_risky_sa += action_result.info["risky_sa"]
        else:
            num_risky_sa = None
        if cum_cost is not None and action_result.cost is not None:
            cum_cost += action_result.cost
        else:
            cum_cost = None
        agent.observe(action_result)
        if learn and not learn_after_episode:
            agent.learn()
        done = action_result.done
        episode_steps += 1

    if learn and learn_after_episode:
        agent.learn()

    info = {"return": cum_reward}
    if num_risky_sa is not None:
        info.update({"risky_sa_ratio": num_risky_sa / episode_steps})
    if cum_cost is not None:
        info.update({"return_cost": cum_cost})

    return info, episode_steps
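
Since run_episode defaults to exploit=True and learn=False, it also works as a simple evaluation loop. The sketch below (hypothetical helper name, arbitrary episode count) averages the greedy-policy return over a few episodes without updating the agent.

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.train_and_eval.online_learning import run_episode


def average_evaluation_return(
    agent: PearlAgent, env: Environment, num_episodes: int = 10
) -> float:
    returns = []
    for _ in range(num_episodes):
        # learn=False and exploit=True: act greedily and leave the agent unchanged.
        info, _episode_steps = run_episode(agent, env, learn=False, exploit=True)
        returns.append(info["return"])
    return sum(returns) / len(returns)
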
def target_return_is_reached(target_return: object, max_episodes: int, agent: PearlAgent, env: Environment, learn: bool, learn_after_episode: bool, exploit: bool, required_target_returns_in_a_row: int = 1, check_moving_average: bool = False) ‑> bool

Learns until obtaining the target return (a certain number of times in a row, default 1) or until max_episodes are completed.

Args

target_return : Value
the targeted return.
max_episodes : int
the maximum number of episodes to run.
agent : PearlAgent
the agent.
env : Environment
the environment.
learn : bool
whether to learn.
learn_after_episode : bool
whether to learn only at the end of each episode.
exploit : bool
whether to exploit.
required_target_returns_in_a_row : int, optional
how many times in a row the target must be reached to succeed. Defaults to 1.
check_moving_average : bool, optional
if enabled, checks whether the latest moving average of returns reaches the target instead of the latest single return.

Returns

bool
whether target_return has been reached required_target_returns_in_a_row times in a row.

Expand source code
def target_return_is_reached(
    target_return: Value,
    max_episodes: int,
    agent: PearlAgent,
    env: Environment,
    learn: bool,
    learn_after_episode: bool,
    exploit: bool,
    required_target_returns_in_a_row: int = 1,
    check_moving_average: bool = False,
) -> bool:
    """
    Learns until obtaining target return (a certain number of times in a row, default 1)
    or max_episodes are completed.
    Args
        target_return (Value): the targeted return.
        max_episodes (int): the maximum number of episodes to run.
        agent (Agent): the agent.
        env (Environment): the environment.
        learn (bool): whether to learn.
        learn_after_episode (bool): whether to learn after every episode.
        exploit (bool): whether to exploit.
        required_target_returns_in_a_row (int, optional): how many times we must hit the target
        to succeed.
        check_moving_average: if this is enabled, we check the if latest moving average value
                              reaches goal
    Returns
        bool: whether target_return has been obtained required_target_returns_in_a_row times
              in a row.
    """
    target_returns_in_a_row = 0
    returns = []
    total_steps = 0
    for i in range(max_episodes):
        if i % 10 == 0 and i != 0:
            print(f"episode {i} return:", returns[-1])
        episode_info, episode_total_steps = run_episode(
            agent=agent,
            env=env,
            learn=learn,
            learn_after_episode=learn_after_episode,
            exploit=exploit,
        )
        total_steps += episode_total_steps
        returns.append(episode_info["return"])
        value = (
            episode_info["return"]
            if not check_moving_average
            else latest_moving_average(returns)
        )
        if value >= target_return:
            target_returns_in_a_row += 1
            if target_returns_in_a_row >= required_target_returns_in_a_row:
                return True
        else:
            target_returns_in_a_row = 0
    return False
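
A usage sketch, assuming an already-constructed `agent` and `env`; the target of 195, the episode budget, and the helper name are arbitrary illustrative choices.

from pearl.api.environment import Environment
from pearl.pearl_agent import PearlAgent
from pearl.utils.functional_utils.train_and_eval.online_learning import (
    target_return_is_reached,
)


def train_until_solved(agent: PearlAgent, env: Environment) -> bool:
    # Train (learning at the end of each episode, acting exploratorily) until the
    # 10-episode moving average of returns reaches 195, giving up after 1000 episodes.
    return target_return_is_reached(
        target_return=195,
        max_episodes=1000,
        agent=agent,
        env=env,
        learn=True,
        learn_after_episode=True,
        exploit=False,
        check_moving_average=True,
    )
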