Module pearl.user_envs.wrappers.safety

Expand source code
# pyre-ignore-all-errors

try:
    import gymnasium as gym
except ModuleNotFoundError:
    print("gymnasium module not found.")


class PuckWorldSafetyWrapper(gym.Wrapper):
    """Reward-perturbing safety wrapper for the PuckWorld environment.

    On every step the first observation component is compared against half
    of ``env.game.width``. When it is strictly greater, a sample obtained
    from ``env.np_random.normal(0.01, sigma)`` is added to the reward and
    ``info["risky_sa"]`` is set to 1; otherwise the reward is passed through
    unchanged and ``info["risky_sa"]`` is set to 0.

    Args:
        env: The environment to wrap. It must expose ``game.width`` and
            ``np_random``.
        sigma: Second argument handed to ``np_random.normal`` when sampling
            the extra reward.
    """

    def __init__(self, env, sigma=0.1):
        super().__init__(env)
        self.sigma = sigma

    def step(self, action):
        """Step the wrapped environment and apply the safety reward.

        Returns:
            The ``(obs, reward, done, truncated, info)`` tuple of the wrapped
            environment, with the reward possibly perturbed and
            ``info["risky_sa"]`` added.
        """
        obs, reward, done, truncated, info = self.env.step(action)
        if obs[0] > self.env.game.width / 2:
            # The generator is only consulted on this branch, so no random
            # number is consumed on non-risky steps.
            noise, risky_flag = self.env.np_random.normal(0.01, self.sigma), 1
        else:
            noise, risky_flag = 0, 0
        info["risky_sa"] = risky_flag
        return obs, reward + noise, done, truncated, info

Classes

class PuckWorldSafetyWrapper (env, sigma=0.1)

Safety wrapper for the PuckWorld environment. Small positive reward with high variance when x > width/2

Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args

env
The environment to wrap
Expand source code
class PuckWorldSafetyWrapper(gym.Wrapper):
    """Reward-perturbing safety wrapper for the PuckWorld environment.

    On every step the first observation component is compared against half
    of ``env.game.width``. When it is strictly greater, a sample obtained
    from ``env.np_random.normal(0.01, sigma)`` is added to the reward and
    ``info["risky_sa"]`` is set to 1; otherwise the reward is passed through
    unchanged and ``info["risky_sa"]`` is set to 0.

    Args:
        env: The environment to wrap. It must expose ``game.width`` and
            ``np_random``.
        sigma: Second argument handed to ``np_random.normal`` when sampling
            the extra reward.
    """

    def __init__(self, env, sigma=0.1):
        super().__init__(env)
        self.sigma = sigma

    def step(self, action):
        """Step the wrapped environment and apply the safety reward.

        Returns:
            The ``(obs, reward, done, truncated, info)`` tuple of the wrapped
            environment, with the reward possibly perturbed and
            ``info["risky_sa"]`` added.
        """
        obs, reward, done, truncated, info = self.env.step(action)
        if obs[0] > self.env.game.width / 2:
            # The generator is only consulted on this branch, so no random
            # number is consumed on non-risky steps.
            noise, risky_flag = self.env.np_random.normal(0.01, self.sigma), 1
        else:
            noise, risky_flag = 0, 0
        info["risky_sa"] = risky_flag
        return obs, reward + noise, done, truncated, info

Ancestors

  • gymnasium.core.Wrapper
  • gymnasium.core.Env
  • typing.Generic

Methods

def step(self, action)

Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.

Expand source code
def step(self, action):
    """Step the wrapped environment and apply the safety reward.

    The first observation component is compared against half of
    ``self.env.game.width``. When it is strictly greater, the value returned
    by ``self.env.np_random.normal(0.01, self.sigma)`` is added to the reward
    and ``info["risky_sa"]`` is set to 1; otherwise the reward is unchanged
    and ``info["risky_sa"]`` is set to 0.

    Args:
        action: Passed straight through to ``self.env.step``.

    Returns:
        The ``(obs, reward, done, truncated, info)`` tuple of the wrapped
        environment, with the reward possibly perturbed and
        ``info["risky_sa"]`` added.
    """
    obs, reward, done, truncated, info = self.env.step(action)
    if obs[0] > self.env.game.width / 2:
        # The generator is only consulted on this branch, so no random
        # number is consumed on non-risky steps.
        noise, risky_flag = self.env.np_random.normal(0.01, self.sigma), 1
    else:
        noise, risky_flag = 0, 0
    info["risky_sa"] = risky_flag
    return obs, reward + noise, done, truncated, info