Module pearl.utils.scripts.figure_gen

To run (assuming the current directory is fbcode/): the first time, use `buck2 run pearl/scripts:figure_gen <num_episodes>`; on later runs, invoke the built binary directly with a path similar to `../buck-out/v2/gen/fbcode/e97a0788aa35bdc8/pearl/scripts/__figure_gen__/figure_gen.par`.

Expand source code
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
"""
To RUN: (assume under fbcode/)
    first time:
    buck2 run pearl/scripts:figure_gen <num_episodes>
    second time, something like:
    ../buck-out/v2/gen/fbcode/e97a0788aa35bdc8/pearl/scripts/__figure_gen__/figure_gen.par
"""
import logging
import sys
from typing import List

import matplotlib.pyplot as plt
from pearl.api.reward import Value
from pearl.pearl_agent import PearlAgent

from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
    DeepQLearning,
)
from pearl.policy_learners.sequential_decision_making.deep_sarsa import DeepSARSA
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
    FIFOOffPolicyReplayBuffer,
)
from pearl.replay_buffers.sequential_decision_making.fifo_on_policy_replay_buffer import (
    FIFOOnPolicyReplayBuffer,
)
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
from pearl.utils.instantiations.environments.gym_environment import GymEnvironment

# Default number of trailing episodes averaged by moving_average().
MA_WINDOW_SIZE = 100.0


def moving_average(
    data: List[Value], window_size: float = MA_WINDOW_SIZE
) -> List[Value]:
    """Return the trailing moving average of ``data``.

    Element ``i`` of the result is the mean of the last ``window_size`` values
    ending at index ``i``.  While fewer than ``window_size`` values are
    available (the start of the series), the mean is taken over all values
    seen so far, so the output always has the same length as the input.

    Args:
        data: Sequence of values to smooth (e.g. per-episode returns).
        window_size: Number of trailing points per average; truncated to an
            integer.  Defaults to ``MA_WINDOW_SIZE``.

    Returns:
        A list with one averaged value per input element; empty if ``data``
        is empty.

    Raises:
        ValueError: If ``window_size`` truncates to less than 1.
    """
    window = int(window_size)
    if window < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    averages: List[Value] = []
    for end in range(1, len(data) + 1):
        # Clamp the window start at 0 so early points average over what exists.
        chunk = data[max(0, end - window) : end]
        averages.append(sum(chunk) / len(chunk))  # pyre-ignore
    return averages


def main() -> None:
    """Train DQN and SARSA agents on CartPole-v1 and plot their returns.

    Reads an optional episode count from the command line (default 1000),
    trains a ``DeepQLearning`` agent and then a ``DeepSARSA`` agent on the
    same ``GymEnvironment`` via ``online_learning``, and plots each agent's
    per-episode return together with its moving average.  The combined figure
    is saved as ``figure_gen.png``.

    Raises:
        SystemExit: If more than one command-line argument is supplied.
        ValueError: If the supplied argument is not an integer.
    """
    logging.basicConfig(level=logging.DEBUG)
    env = GymEnvironment("CartPole-v1")

    # Use an explicit check instead of ``assert``: assertions are stripped
    # under ``python -O``, which would silently skip the usage validation.
    if len(sys.argv) > 2:
        raise SystemExit("Usage: <figure_gen> <num_episodes>")
    num_episodes = 1000  # default value
    if len(sys.argv) == 2:
        num_episodes = int(sys.argv[1])

    state_dim = env.observation_space.shape[0]

    def train_and_plot(agent: PearlAgent, label: str, ma_label: str) -> None:
        # Train one agent on the shared environment, then plot its raw
        # per-episode returns and their moving average.
        info = online_learning(
            agent,
            env,
            number_of_episodes=num_episodes,
            learn_after_episode=True,
        )
        returns = info["return"]
        plt.plot(returns, label=label)
        plt.plot(moving_average(returns), label=ma_label)

    # Each agent is constructed immediately before it is trained, so the
    # construction/training order matches a plain sequential script.
    train_and_plot(
        PearlAgent(
            policy_learner=DeepQLearning(
                state_dim=state_dim,
                action_space=env.action_space,
                hidden_dims=[64, 64],
                training_rounds=20,
            ),
            replay_buffer=FIFOOffPolicyReplayBuffer(10000),
        ),
        label="vanilla dqn",
        ma_label="dqn_ma",
    )
    train_and_plot(
        PearlAgent(
            policy_learner=DeepSARSA(
                state_dim,
                env.action_space,
                hidden_dims=[64, 64],
                training_rounds=20,
            ),
            replay_buffer=FIFOOnPolicyReplayBuffer(10000),
        ),
        label="sarsa",
        ma_label="sarsa_ma",
    )

    plt.xlabel("Episode")
    plt.ylabel("Return")
    plt.legend()
    plt.savefig("figure_gen.png")


# Run the figure-generation entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()

Functions

def main() ‑> None
Expand source code
def main() -> None:
    """Train DQN and SARSA agents on CartPole-v1 and plot their returns.

    Reads an optional episode count from the command line (default 1000),
    trains a ``DeepQLearning`` agent and then a ``DeepSARSA`` agent on the
    same ``GymEnvironment`` via ``online_learning``, and plots each agent's
    per-episode return together with its moving average.  The combined figure
    is saved as ``figure_gen.png``.

    Raises:
        SystemExit: If more than one command-line argument is supplied.
        ValueError: If the supplied argument is not an integer.
    """
    logging.basicConfig(level=logging.DEBUG)
    env = GymEnvironment("CartPole-v1")

    # Use an explicit check instead of ``assert``: assertions are stripped
    # under ``python -O``, which would silently skip the usage validation.
    if len(sys.argv) > 2:
        raise SystemExit("Usage: <figure_gen> <num_episodes>")
    num_episodes = 1000  # default value
    if len(sys.argv) == 2:
        num_episodes = int(sys.argv[1])

    state_dim = env.observation_space.shape[0]

    def train_and_plot(agent: PearlAgent, label: str, ma_label: str) -> None:
        # Train one agent on the shared environment, then plot its raw
        # per-episode returns and their moving average.
        info = online_learning(
            agent,
            env,
            number_of_episodes=num_episodes,
            learn_after_episode=True,
        )
        returns = info["return"]
        plt.plot(returns, label=label)
        plt.plot(moving_average(returns), label=ma_label)

    # Each agent is constructed immediately before it is trained, so the
    # construction/training order matches a plain sequential script.
    train_and_plot(
        PearlAgent(
            policy_learner=DeepQLearning(
                state_dim=state_dim,
                action_space=env.action_space,
                hidden_dims=[64, 64],
                training_rounds=20,
            ),
            replay_buffer=FIFOOffPolicyReplayBuffer(10000),
        ),
        label="vanilla dqn",
        ma_label="dqn_ma",
    )
    train_and_plot(
        PearlAgent(
            policy_learner=DeepSARSA(
                state_dim,
                env.action_space,
                hidden_dims=[64, 64],
                training_rounds=20,
            ),
            replay_buffer=FIFOOnPolicyReplayBuffer(10000),
        ),
        label="sarsa",
        ma_label="sarsa_ma",
    )

    plt.xlabel("Episode")
    plt.ylabel("Return")
    plt.legend()
    plt.savefig("figure_gen.png")
def moving_average(data: List[object]) ‑> object
Expand source code
def moving_average(
    data: List[Value], window_size: float = MA_WINDOW_SIZE
) -> List[Value]:
    """Return the trailing moving average of ``data``.

    Element ``i`` of the result is the mean of the last ``window_size`` values
    ending at index ``i``.  While fewer than ``window_size`` values are
    available (the start of the series), the mean is taken over all values
    seen so far, so the output always has the same length as the input.

    Args:
        data: Sequence of values to smooth (e.g. per-episode returns).
        window_size: Number of trailing points per average; truncated to an
            integer.  Defaults to ``MA_WINDOW_SIZE``.

    Returns:
        A list with one averaged value per input element; empty if ``data``
        is empty.

    Raises:
        ValueError: If ``window_size`` truncates to less than 1.
    """
    window = int(window_size)
    if window < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    averages: List[Value] = []
    for end in range(1, len(data) + 1):
        # Clamp the window start at 0 so early points average over what exists.
        chunk = data[max(0, end - window) : end]
        averages.append(sum(chunk) / len(chunk))  # pyre-ignore
    return averages