Module pearl.user_envs.wrappers.safety
Expand source code
# pyre-ignore-all-errors
try:
import gymnasium as gym
except ModuleNotFoundError:
print("gymnasium module not found.")
class PuckWorldSafetyWrapper(gym.Wrapper):
    r"""Safety wrapper for the PuckWorld environment.

    Small positive reward with high variance when x > width/2: whenever the
    first observation component (``obs[0]``, presumably the puck's x
    position) exceeds half of ``game.width``, a sample from
    ``Normal(0.01, sigma)`` is added to the environment reward.

    Every step also writes ``info["risky_sa"]``: ``1`` when the bonus region
    was entered, ``0`` otherwise.

    Args:
        env: The environment to wrap. Its base (unwrapped) environment is
            expected to expose ``game.width``.
        sigma: Standard deviation of the Gaussian bonus reward.
    """

    def __init__(self, env, sigma=0.1):
        super().__init__(env)
        self.sigma = sigma

    def step(self, action):
        """Step the wrapped env and add the risky-region reward bonus.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` from the wrapped
            env. ``reward`` is shifted by a ``Normal(0.01, sigma)`` sample
            when ``obs[0] > game.width / 2``, and ``info["risky_sa"]`` is
            set to ``1`` in that case and ``0`` otherwise.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        x = obs[0]
        safety_reward = 0
        info["risky_sa"] = 0
        # Read `game` from the base env: outer wrappers (such as those added
        # by `gym.make`) no longer forward arbitrary attributes in
        # gymnasium >= 1.0, so `self.env.game` would raise there.
        if x > self.env.unwrapped.game.width / 2:
            safety_reward = self.env.np_random.normal(0.01, self.sigma)
            info["risky_sa"] = 1
        return obs, reward + safety_reward, terminated, truncated, info
Classes
class PuckWorldSafetyWrapper (env, sigma=0.1)
-
Safety wrapper for the PuckWorld environment. Small positive reward with high variance when x > width/2
Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args:
    env: The environment to wrap.
Expand source code
class PuckWorldSafetyWrapper(gym.Wrapper):
    r"""Safety wrapper for the PuckWorld environment.

    Small positive reward with high variance when x > width/2: whenever the
    first observation component (``obs[0]``, presumably the puck's x
    position) exceeds half of ``game.width``, a sample from
    ``Normal(0.01, sigma)`` is added to the environment reward.

    Every step also writes ``info["risky_sa"]``: ``1`` when the bonus region
    was entered, ``0`` otherwise.

    Args:
        env: The environment to wrap. Its base (unwrapped) environment is
            expected to expose ``game.width``.
        sigma: Standard deviation of the Gaussian bonus reward.
    """

    def __init__(self, env, sigma=0.1):
        super().__init__(env)
        self.sigma = sigma

    def step(self, action):
        """Step the wrapped env and add the risky-region reward bonus.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` from the wrapped
            env. ``reward`` is shifted by a ``Normal(0.01, sigma)`` sample
            when ``obs[0] > game.width / 2``, and ``info["risky_sa"]`` is
            set to ``1`` in that case and ``0`` otherwise.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        x = obs[0]
        safety_reward = 0
        info["risky_sa"] = 0
        # Read `game` from the base env: outer wrappers (such as those added
        # by `gym.make`) no longer forward arbitrary attributes in
        # gymnasium >= 1.0, so `self.env.game` would raise there.
        if x > self.env.unwrapped.game.width / 2:
            safety_reward = self.env.np_random.normal(0.01, self.sigma)
            info["risky_sa"] = 1
        return obs, reward + safety_reward, terminated, truncated, info
Ancestors
- gymnasium.core.Wrapper
- gymnasium.core.Env
- typing.Generic
Methods
def step(self, action)
-
Uses the :meth:`step` of the :attr:`env`; can be overwritten to change the returned data.
Expand source code
def step(self, action):
    """Step the wrapped env and add the risky-region reward bonus.

    When ``obs[0]`` exceeds half of the base env's ``game.width``, a sample
    from ``Normal(0.01, self.sigma)`` is added to the reward.

    Returns:
        ``(obs, reward, terminated, truncated, info)`` from the wrapped env,
        with ``reward`` shifted by the bonus in the risky region and
        ``info["risky_sa"]`` set to ``1`` there and ``0`` otherwise.
    """
    obs, reward, terminated, truncated, info = self.env.step(action)
    x = obs[0]
    safety_reward = 0
    info["risky_sa"] = 0
    # Read `game` from the base env: outer wrappers (such as those added by
    # `gym.make`) no longer forward arbitrary attributes in gymnasium >= 1.0,
    # so `self.env.game` would raise there.
    if x > self.env.unwrapped.game.width / 2:
        safety_reward = self.env.np_random.normal(0.01, self.sigma)
        info["risky_sa"] = 1
    return obs, reward + safety_reward, terminated, truncated, info