Module pearl.utils.instantiations.spaces.utils
Expand source code
from torch import Tensor
def reshape_to_1d_tensor(x: Tensor) -> Tensor:
    """Reshape a scalar, `d`, or `1 x d` Tensor into a 1-D Tensor of shape `d`.

    Args:
        x: Input tensor. Supported shapes are `()` (scalar), `(d,)` and `(1, d)`.

    Returns:
        A 1-D tensor:
          - a 1-D input is returned unchanged (the very same object);
          - a scalar becomes a tensor of shape `(1,)`;
          - a `1 x d` input becomes a tensor of shape `(d,)`.

    Raises:
        ValueError: If `x` has any other shape, e.g. `n x d` with `n != 1`
            or three or more dimensions.
    """
    if x.ndim == 1:
        # Already `d`: nothing to do.
        return x
    if x.ndim == 0:
        # scalar -> `1` (a 1-D tensor holding the single value)
        x = x.unsqueeze(dim=0)
    elif x.ndim == 2 and x.shape[0] == 1:
        # `1 x d` -> `d`
        x = x.squeeze(dim=0)
    else:
        raise ValueError(f"Tensor of shape {x.shape} is not supported.")
    return x
Functions
def reshape_to_1d_tensor(x: torch.Tensor) -> torch.Tensor
Reshapes a Tensor that is either scalar or `1 x d` -> `d`.
Expand source code
def reshape_to_1d_tensor(x: Tensor) -> Tensor:
    """Reshape a scalar, `d`, or `1 x d` Tensor into a 1-D Tensor of shape `d`.

    Args:
        x: Input tensor. Supported shapes are `()` (scalar), `(d,)` and `(1, d)`.

    Returns:
        A 1-D tensor:
          - a 1-D input is returned unchanged (the very same object);
          - a scalar becomes a tensor of shape `(1,)`;
          - a `1 x d` input becomes a tensor of shape `(d,)`.

    Raises:
        ValueError: If `x` has any other shape, e.g. `n x d` with `n != 1`
            or three or more dimensions.
    """
    if x.ndim == 1:
        # Already `d`: nothing to do.
        return x
    if x.ndim == 0:
        # scalar -> `1` (a 1-D tensor holding the single value)
        x = x.unsqueeze(dim=0)
    elif x.ndim == 2 and x.shape[0] == 1:
        # `1 x d` -> `d`
        x = x.squeeze(dim=0)
    else:
        raise ValueError(f"Tensor of shape {x.shape} is not supported.")
    return x