Module pearl.neural_networks.common.utils
Expand source code
import logging
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.func import stack_module_state
from .residual_wrapper import ResidualWrapper
ACTIVATION_MAP = {
"tanh": nn.Tanh,
"relu": nn.ReLU,
"leaky_relu": nn.LeakyReLU,
"linear": nn.Identity,
"sigmoid": nn.Sigmoid,
"softplus": nn.Softplus,
"softmax": nn.Softmax,
}
def mlp_block(
input_dim: int,
hidden_dims: Optional[List[int]],
output_dim: int = 1,
use_batch_norm: bool = False,
use_layer_norm: bool = False,
hidden_activation: str = "relu",
last_activation: Optional[str] = None,
dropout_ratio: float = 0.0,
use_skip_connections: bool = False,
**kwargs: Any,
) -> nn.Module:
"""
A simple MLP which can be reused to create more complex networks
Args:
input_dim: dimension of the input layer
hidden_dims: a list of dimensions of the hidden layers
output_dim: dimension of the output layer
        use_batch_norm: whether to use batch norm in the hidden layers
        use_layer_norm: whether to use layer norm in the hidden layers
        hidden_activation: activation function used for the hidden layers
        last_activation: optional activation for the last layer; if not set,
            no activation is applied to the last layer
        dropout_ratio: dropout probability applied in the hidden layers; the user needs to
            call nn.Module.eval to ensure dropout is disabled when acting/evaluating
        use_skip_connections: whether to wrap each layer in a residual (skip) connection
            when its input and output dimensions match
Returns:
an nn.Sequential module consisting of mlp layers
"""
if hidden_dims is None:
hidden_dims = []
dims = [input_dim] + hidden_dims + [output_dim]
layers = []
for i in range(len(dims) - 2):
single_layers = []
input_dim_current_layer = dims[i]
output_dim_current_layer = dims[i + 1]
single_layers.append(
nn.Linear(input_dim_current_layer, output_dim_current_layer)
)
if use_layer_norm:
single_layers.append(nn.LayerNorm(output_dim_current_layer))
if dropout_ratio > 0:
single_layers.append(nn.Dropout(p=dropout_ratio))
single_layers.append(ACTIVATION_MAP[hidden_activation]())
if use_batch_norm:
single_layers.append(nn.BatchNorm1d(output_dim_current_layer))
single_layer_model = nn.Sequential(*single_layers)
if use_skip_connections:
if input_dim_current_layer == output_dim_current_layer:
single_layer_model = ResidualWrapper(single_layer_model)
else:
                logging.warning(
                    "Skip connections are enabled, "
                    f"but layer in_dim ({input_dim_current_layer}) != out_dim "
                    f"({output_dim_current_layer}). "
                    "Skip connection will not be added for this layer."
)
layers.append(single_layer_model)
last_layer = []
last_layer.append(nn.Linear(dims[-2], dims[-1]))
if last_activation is not None:
last_layer.append(ACTIVATION_MAP[last_activation]())
last_layer_model = nn.Sequential(*last_layer)
if use_skip_connections:
if dims[-2] == dims[-1]:
last_layer_model = ResidualWrapper(last_layer_model)
else:
            logging.warning(
                "Skip connections are enabled, "
                f"but layer in_dim ({dims[-2]}) != out_dim ({dims[-1]}). "
                "Skip connection will not be added for this layer."
)
layers.append(last_layer_model)
return nn.Sequential(*layers)
def conv_block(
input_channels_count: int,
output_channels_list: List[int],
kernel_sizes: List[int],
strides: List[int],
paddings: List[int],
use_batch_norm: bool = False,
) -> nn.Module:
"""
    Reminder: torch.nn.Conv2d layers expect inputs as (batch_size, in_channels, height, width)
    Note: layer norm is typically not used with CNNs
Args:
input_channels_count: number of input channels
output_channels_list: a list of number of output channels for each convolutional layer
kernel_sizes: a list of kernel sizes for each layer
strides: a list of strides for each layer
paddings: a list of paddings for each layer
use_batch_norm: whether to use batch_norm or not in the convolutional layers
Returns:
an nn.Sequential module consisting of convolutional layers
"""
layers = []
for out_channels, kernel_size, stride, padding in zip(
output_channels_list, kernel_sizes, strides, paddings
):
conv_layer = nn.Conv2d(
input_channels_count,
out_channels,
kernel_size,
stride=stride,
padding=padding,
)
layers.append(conv_layer)
        if use_batch_norm and input_channels_count > 1:
            layers.append(
                # BatchNorm2d normalizes the output of the preceding conv layer,
                # so it is parameterized by that layer's number of output channels
                nn.BatchNorm2d(out_channels)
            )
layers.append(nn.ReLU())
# number of input channels to next layer is number of output channels of previous layer:
input_channels_count = out_channels
return nn.Sequential(*layers)
# TODO: the name of this function needs to be revised to xavier_init_weights
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
m.bias.data.fill_(0.01)
def uniform_init_weights(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
nn.init.uniform_(m.weight, -0.001, 0.001)
nn.init.uniform_(m.bias, -0.001, 0.001)
def update_target_network(
target_network: nn.Module, source_network: nn.Module, tau: float
) -> None:
    # Q_target = (1 - tau) * Q_target + tau * Q
for target_param, source_param in zip(
target_network.parameters(), source_network.parameters()
):
if target_param is source_param:
            # skip soft-updating when the target network shares the parameter with
            # the network being trained.
continue
new_param = tau * source_param.data + (1.0 - tau) * target_param.data
target_param.data.copy_(new_param)
def ensemble_forward(
models: Union[nn.ModuleList, List[nn.Module]],
features: torch.Tensor,
use_for_loop: bool = True,
) -> torch.Tensor:
"""
Run forward pass on several models and return their outputs stacked as a tensor.
If use_for_loop is False, a vectorized implementation is used, which has some
limitations (see https://fburl.com/code/m4l2tjof):
1. All models must have the same structure.
2. Gradient backpropagation to original model parameters might not work properly.
Args:
models: list of models to run forward pass on. Length: num_models
features: features to run forward pass on. shape: (batch_size, num_models, num_features)
use_for_loop: whether to use for loop or vectorized implementation
Output:
A tensor of shape (batch_size, num_models)
"""
torch._assert(
features.ndim == 3,
"Features should be of shape (batch_size, num_models, num_features)",
)
torch._assert(
features.shape[1] == len(models),
"Number of models must match features.shape[1]",
)
batch_size = features.shape[0]
if use_for_loop:
values = [model(features[:, i, :]).flatten() for i, model in enumerate(models)]
return torch.stack(values, dim=-1) # (batch_size, ensemble_size)
else:
features = features.permute((1, 0, 2))
def wrapper(
params: Dict[str, torch.Tensor],
buffers: Dict[str, torch.Tensor],
data: torch.Tensor,
) -> torch.Tensor:
return torch.func.functional_call(models[0], (params, buffers), data)
params, buffers = stack_module_state(models)
values = torch.vmap(wrapper)(params, buffers, features).view(
(-1, batch_size)
) # (ensemble_size, batch_size)
# change shape to (batch_size, ensemble_size)
return values.permute(1, 0)
def update_target_networks(
list_of_target_networks: Union[nn.ModuleList, List[nn.Module]],
list_of_source_networks: Union[nn.ModuleList, List[nn.Module]],
tau: float,
) -> None:
"""
Args:
        list_of_target_networks: list (or nn.ModuleList) of target networks
        list_of_source_networks: list (or nn.ModuleList) of source networks
tau: parameter for soft update
"""
    # Q_target = (1 - tau) * Q_target + tau * Q
for target_network, source_network in zip(
list_of_target_networks, list_of_source_networks
):
update_target_network(target_network, source_network, tau)
Functions
def conv_block(input_channels_count: int, output_channels_list: List[int], kernel_sizes: List[int], strides: List[int], paddings: List[int], use_batch_norm: bool = False) ‑> torch.nn.modules.module.Module
Reminder: torch.nn.Conv2d layers expect inputs as (batch_size, in_channels, height, width)
Note: layer norm is typically not used with CNNs
Args
input_channels_count
- number of input channels
output_channels_list
- a list of number of output channels for each convolutional layer
kernel_sizes
- a list of kernel sizes for each layer
strides
- a list of strides for each layer
paddings
- a list of paddings for each layer
use_batch_norm
- whether to use batch_norm or not in the convolutional layers
Returns
an nn.Sequential module consisting of convolutional layers
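As a usage sketch (the channel counts, kernel sizes, and input resolution below are made up for illustration):
import torch
from pearl.neural_networks.common.utils import conv_block

# Hypothetical example: two conv layers over 84x84 RGB frames.
cnn = conv_block(
    input_channels_count=3,
    output_channels_list=[16, 32],
    kernel_sizes=[8, 4],
    strides=[4, 2],
    paddings=[0, 0],
)
x = torch.randn(10, 3, 84, 84)   # (batch_size, in_channels, height, width)
out = cnn(x)                     # (10, 32, 9, 9) for these kernel/stride choices
flat = out.flatten(start_dim=1)  # flatten before feeding a fully-connected head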
def ensemble_forward(models: Union[torch.nn.modules.container.ModuleList, List[torch.nn.modules.module.Module]], features: torch.Tensor, use_for_loop: bool = True) ‑> torch.Tensor
Run forward pass on several models and return their outputs stacked as a tensor.
If use_for_loop is False, a vectorized implementation is used, which has some limitations (see https://fburl.com/code/m4l2tjof):
1. All models must have the same structure.
2. Gradient backpropagation to original model parameters might not work properly.
Args
models
- list of models to run forward pass on. Length: num_models
features
- features to run forward pass on. shape: (batch_size, num_models, num_features)
use_for_loop
- whether to use for loop or vectorized implementation
Output
A tensor of shape (batch_size, num_models)
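A minimal sketch of calling ensemble_forward on a hypothetical ensemble of identically structured MLPs (the sizes below are arbitrary):
import torch
import torch.nn as nn
from pearl.neural_networks.common.utils import ensemble_forward, mlp_block

num_models, batch_size, num_features = 5, 32, 16
models = nn.ModuleList(
    [mlp_block(num_features, [64, 64], output_dim=1) for _ in range(num_models)]
)
features = torch.randn(batch_size, num_models, num_features)

values_loop = ensemble_forward(models, features, use_for_loop=True)   # (32, 5)
values_vmap = ensemble_forward(models, features, use_for_loop=False)  # (32, 5)
The vectorized path stacks the parameters of all models and runs a single vmapped forward pass, which is why every model must share the same architecture.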
def init_weights(m: torch.nn.modules.module.Module) ‑> None
Initializes the weights of nn.Linear layers with Xavier uniform initialization and sets their biases to 0.01.
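A short sketch of applying this initializer to every nn.Linear submodule via nn.Module.apply (network sizes are arbitrary):
from pearl.neural_networks.common.utils import init_weights, mlp_block

net = mlp_block(input_dim=8, hidden_dims=[32, 32], output_dim=4)
net.apply(init_weights)  # Xavier-uniform weights, biases filled with 0.01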
def mlp_block(input_dim: int, hidden_dims: Optional[List[int]], output_dim: int = 1, use_batch_norm: bool = False, use_layer_norm: bool = False, hidden_activation: str = 'relu', last_activation: Optional[str] = None, dropout_ratio: float = 0.0, use_skip_connections: bool = False, **kwargs: Any) ‑> torch.nn.modules.module.Module
A simple MLP which can be reused to create more complex networks
Args
input_dim
- dimension of the input layer
hidden_dims
- a list of dimensions of the hidden layers
output_dim
- dimension of the output layer
use_batch_norm
- whether to use batch norm in the hidden layers
use_layer_norm
- whether to use layer norm in the hidden layers
hidden_activation
- activation function used for the hidden layers
last_activation
- optional activation for the last layer; if not set, no activation is applied to the last layer
dropout_ratio
- dropout probability applied in the hidden layers; the user needs to call nn.Module.eval to ensure dropout is disabled when acting/evaluating
use_skip_connections
- whether to wrap each layer in a residual (skip) connection when its input and output dimensions match
Returns
an nn.Sequential module consisting of mlp layers
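A usage sketch, assuming an arbitrary 10-dimensional input and a scalar output (for example, a small value-network head):
import torch
from pearl.neural_networks.common.utils import mlp_block

net = mlp_block(
    input_dim=10,
    hidden_dims=[64, 64],
    output_dim=1,
    hidden_activation="relu",
    use_layer_norm=True,
    dropout_ratio=0.1,
)
x = torch.randn(32, 10)  # (batch_size, input_dim)
y = net(x)               # (32, 1)
net.eval()               # disables dropout when acting / evaluating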
def uniform_init_weights(m: torch.nn.modules.module.Module) ‑> None
Initializes the weights and biases of nn.Linear layers uniformly in [-0.001, 0.001].
def update_target_network(target_network: torch.nn.modules.module.Module, source_network: torch.nn.modules.module.Module, tau: float) ‑> None
Soft-updates the parameters of target_network towards source_network: target_param = tau * source_param + (1 - tau) * target_param. Parameters shared between the two networks are skipped.
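A minimal sketch of a soft (Polyak) update, assuming the target network starts as a deep copy of the online network; tau=0.005 is an arbitrary illustrative value:
import copy
from pearl.neural_networks.common.utils import mlp_block, update_target_network

q_net = mlp_block(input_dim=10, hidden_dims=[64], output_dim=1)
target_q_net = copy.deepcopy(q_net)

# After each learning step:
# target_param <- tau * source_param + (1 - tau) * target_param
update_target_network(target_q_net, q_net, tau=0.005)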
def update_target_networks(list_of_target_networks: Union[torch.nn.modules.container.ModuleList, List[torch.nn.modules.module.Module]], list_of_source_networks: Union[torch.nn.modules.container.ModuleList, List[torch.nn.modules.module.Module]], tau: float) ‑> None
Args
list_of_target_networks
- list (or nn.ModuleList) of target networks
list_of_source_networks
- list (or nn.ModuleList) of source networks
tau
- parameter for soft update
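The same soft update applied to a list of networks, e.g. a hypothetical pair of twin critics with matching target networks (sizes and tau are arbitrary):
import copy
import torch.nn as nn
from pearl.neural_networks.common.utils import mlp_block, update_target_networks

critics = nn.ModuleList([mlp_block(10, [64], 1) for _ in range(2)])
target_critics = copy.deepcopy(critics)

update_target_networks(target_critics, critics, tau=0.005)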