Module pearl.utils.functional_utils.learning.linear_regression
Linear Regression Classes
Currently used for LinUCB and disjoint LinUCB
paper: https://arxiv.org/pdf/1003.0146.pdf
TODO: Distribution and many other production issues still need to be considered.
This module currently contains only the simplest logic.
Before migrating to production, a code review should be scheduled to compare with ReAgent:
fbcode/reagent/models/disjoint_linucb_predictor.py
fbcode/reagent/models/linear_regression.py
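A minimal usage sketch (not part of the module; the toy data, the exploration coefficient alpha, and the variable names are illustrative assumptions) showing how the class is typically combined into a LinUCB-style score mu + alpha * sigma:

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

# Hypothetical toy data: 4 samples with 2 features each.
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = torch.tensor([1.0, 2.0, 3.0, 4.0])

model = LinearRegression(feature_dim=2)
model.learn_batch(x, y, weight=None)   # updates A, b and recomputes coefs

mu = model(x)                          # point predictions, shape (4,)
sigma = model.calculate_sigma(x)       # uncertainty sqrt(x^T A^{-1} x), shape (4,)

alpha = 1.0                            # assumed exploration coefficient
ucb_score = mu + alpha * sigma         # LinUCB-style upper confidence bound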
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Linear Regression Classes
Currently used for LinUCB and disjoint LinUCB
paper: https://arxiv.org/pdf/1003.0146.pdf
TODO: Distribution and many other production issues still need to be considered.
This module currently contains only the simplest logic.
Before migrating to production, a code review should be scheduled to compare with ReAgent:
fbcode/reagent/models/disjoint_linucb_predictor.py
fbcode/reagent/models/linear_regression.py
"""
import logging
from typing import Optional, Tuple
import torch
from pearl.utils.device import is_distribution_enabled
from torch import nn
logger: logging.Logger = logging.getLogger(__name__)
class LinearRegression(nn.Module):
    def __init__(self, feature_dim: int, l2_reg_lambda: float = 1.0) -> None:
        """
        feature_dim: number of features
        l2_reg_lambda: L2 regularization parameter
        """
        super(LinearRegression, self).__init__()
        self.register_buffer(
            "_A",
            l2_reg_lambda * torch.eye(feature_dim + 1),  # +1 for intercept
        )
        self.register_buffer("_b", torch.zeros(feature_dim + 1))
        self.register_buffer("_sum_weight", torch.zeros(1))
        self.register_buffer(
            "_inv_A",
            torch.zeros(feature_dim + 1, feature_dim + 1),
        )
        self.register_buffer("_coefs", torch.zeros(feature_dim + 1))
        self._feature_dim = feature_dim
        self.distribution_enabled: bool = is_distribution_enabled()

    @property
    def A(self) -> torch.Tensor:
        return self._A

    @property
    def coefs(self) -> torch.Tensor:
        return self._coefs

    @staticmethod
    def batch_quadratic_form(x: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        Compute the quadratic form x^T * A * x for a batched input x.
        The calculation of pred_sigma (uncertainty) in LinUCB is done by quadratic form x^T * A^{-1} * x.
        Inspired by https://stackoverflow.com/questions/18541851/calculate-vt-a-v-for-a-matrix-of-vectors-v  # noqa: E501
        This is a vectorized implementation of out[i] = x[i].t() @ A @ x[i]
        x shape: (Batch, Feature_dim)
        A shape: (Feature_dim, Feature_dim)
        output shape: (Batch)
        """
        return (torch.matmul(x, A) * x).sum(-1)

    @staticmethod
    def append_ones(x: torch.Tensor) -> torch.Tensor:
        """
        Append a column of ones to x (for the intercept of the linear regression).
        We append at the beginning along the last (feature) dimension.
        """
        # Create a tensor of ones to append
        ones = torch.ones_like(torch.select(x, dim=-1, index=0).unsqueeze(-1))
        # Concatenate the input data with the tensor of ones along the last dimension
        result = torch.cat((ones, x), dim=-1)
        return result

    @staticmethod
    def matrix_inv_fallback_pinv(A: torch.Tensor) -> torch.Tensor:
        """
        Try to apply the regular matrix inverse. If it fails, fall back to the pseudo-inverse.
        """
        try:
            inv_A = torch.linalg.inv(A).contiguous()
        # pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`.
        except torch._C._LinAlgError as e:
            logger.warning(
                "Exception raised during A inversion, falling back to pseudo-inverse: %s",
                e,
            )
            # switch from `inv` to `pinv`
            # first check if A is Hermitian (symmetric A)
            A_is_hermitian = torch.allclose(A, A.T, atol=1e-4, rtol=1e-4)
            # applying hermitian=True saves about 50% computations
            inv_A = torch.linalg.pinv(
                A,
                hermitian=A_is_hermitian,
            ).contiguous()
        return inv_A

    def _validate_train_inputs(
        self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = x.shape[0]
        if weight is None:
            weight = torch.ones_like(y)
        assert x.shape == (
            batch_size,
            self._feature_dim,
        ), f"x has shape {x.shape} != {(batch_size, self._feature_dim)}"
        assert y.shape == (batch_size,), f"y has shape {y.shape} != {(batch_size,)}"
        assert weight.shape == (
            batch_size,
        ), f"weight has shape {weight.shape} != {(batch_size,)}"
        y = torch.unsqueeze(y, dim=1)
        weight = torch.unsqueeze(weight, dim=1)
        x = self.append_ones(x)
        return x, y, weight

    def learn_batch(
        self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]
    ) -> None:
        """
        Weighted update of the sufficient statistics:
        A <- A + sum_i w_i * x_i * x_i^T
        b <- b + sum_i w_i * y_i * x_i
        """
        # this also appends a column of ones to `x`
        x, y, weight = self._validate_train_inputs(x, y, weight)
        delta_A = torch.matmul(x.t(), x * weight)
        delta_b = torch.matmul(x.t(), y * weight).squeeze()
        delta_sum_weight = weight.sum()
        if self.distribution_enabled:
            torch.distributed.all_reduce(delta_A)
            torch.distributed.all_reduce(delta_b)
            torch.distributed.all_reduce(delta_sum_weight)
        self._A += delta_A.to(self._A.device)
        self._b += delta_b.to(self._b.device)
        self._sum_weight += delta_sum_weight.to(self._sum_weight.device)
        self.calculate_coefs()  # update coefs after updating A and b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x can be a single vector or a batch.
        If x is a batch, it has shape (batch_size, ...); the return value has shape (batch_size,).
        """
        x = self.append_ones(x)
        return torch.matmul(x, self._coefs.t())

    def calculate_coefs(self) -> None:
        """
        Calculate coefficients based on current A and b.
        Save inverted A and coefficients in buffers.
        """
        self._inv_A = self.matrix_inv_fallback_pinv(self._A)
        self._coefs = torch.matmul(self._inv_A, self._b)

    def calculate_sigma(self, x: torch.Tensor) -> torch.Tensor:
        x = self.append_ones(x)  # append a column of ones for intercept
        sigma = torch.sqrt(self.batch_quadratic_form(x, self._inv_A))
        return sigma

    def __str__(self) -> str:
        return f"LinearRegression(A:\n{self._A}\nb:\n{self._b})"
Classes
class LinearRegression (feature_dim: int, l2_reg_lambda: float = 1.0)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.
feature_dim: number of features
l2_reg_lambda: L2 regularization parameter
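A small construction sketch (illustrative assumptions only): every buffer is sized feature_dim + 1 because append_ones prepends an intercept column.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

model = LinearRegression(feature_dim=3, l2_reg_lambda=1.0)
print(model.A.shape)      # torch.Size([4, 4]) -- feature_dim + 1 for the intercept
print(model.coefs.shape)  # torch.Size([4])    -- all zeros before any learn_batch call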
Ancestors
- torch.nn.modules.module.Module
Static methods
def append_ones(x: torch.Tensor) ‑> torch.Tensor
-
Append a column of ones to x (for the intercept of the linear regression).
We append at the beginning along the last (feature) dimension.
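A quick illustration with an assumed toy tensor: the ones column is placed first, so index 0 of the learned coefficients corresponds to the intercept.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

x = torch.tensor([[2.0, 3.0], [4.0, 5.0]])   # shape (2, 2)
x1 = LinearRegression.append_ones(x)
print(x1)
# tensor([[1., 2., 3.],
#         [1., 4., 5.]])                       shape (2, 3)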
def batch_quadratic_form(x: torch.Tensor, A: torch.Tensor) ‑> torch.Tensor
-
Compute the quadratic form x^T * A * x for a batched input x.
The calculation of pred_sigma (uncertainty) in LinUCB is done by quadratic form x^T * A^{-1} * x.
Inspired by https://stackoverflow.com/questions/18541851/calculate-vt-a-v-for-a-matrix-of-vectors-v
This is a vectorized implementation of out[i] = x[i].t() @ A @ x[i]
x shape: (Batch, Feature_dim)
A shape: (Feature_dim, Feature_dim)
output shape: (Batch)
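A short equivalence check on assumed random inputs: the vectorized expression matches a per-row loop computing x[i] @ A @ x[i].

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

x = torch.randn(5, 3)
A = torch.randn(3, 3)
vectorized = LinearRegression.batch_quadratic_form(x, A)   # shape (5,)
looped = torch.stack([xi @ A @ xi for xi in x])            # same values, row by row
assert torch.allclose(vectorized, looped, atol=1e-5)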
def matrix_inv_fallback_pinv(A: torch.Tensor) ‑> torch.Tensor
-
Try to apply the regular matrix inverse. If it fails, fall back to the pseudo-inverse.
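A hedged sketch of the fallback path; the singular matrix below is an assumed example. torch.linalg.inv raises on a rank-deficient input, so the helper should return the pseudo-inverse instead.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

singular = torch.tensor([[1.0, 2.0], [2.0, 4.0]])   # rank 1, not invertible
inv = LinearRegression.matrix_inv_fallback_pinv(singular)
print(inv)                                           # pseudo-inverse of the input
print(torch.isfinite(inv).all())                     # tensor(True)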
Instance variables
var A : torch.Tensor
-
The running regularized design matrix A = l2_reg_lambda * I + sum_i w_i * x_i * x_i^T, where features include the prepended intercept column.
var coefs : torch.Tensor
-
The current regression coefficients, computed as A^{-1} * b in calculate_coefs(); index 0 is the intercept.
Methods
def calculate_coefs(self) ‑> None
-
Calculate coefficients based on current A and b. Save inverted A and coefficients in buffers.
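A hedged numerical check on assumed toy data: after a single learn_batch call, coefs should match the ridge-regression normal equation (X^T X + lambda * I)^{-1} X^T y computed directly, where X already contains the intercept column.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = torch.tensor([1.0, 2.0, 3.0])

model = LinearRegression(feature_dim=2, l2_reg_lambda=1.0)
model.learn_batch(x, y, weight=None)        # calls calculate_coefs internally

X = LinearRegression.append_ones(x)         # (3, 3) with intercept column
closed_form = torch.linalg.solve(
    X.t() @ X + torch.eye(3),                # X^T X + lambda * I, lambda = 1.0
    X.t() @ y,
)
assert torch.allclose(model.coefs, closed_form, atol=1e-3)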
def calculate_sigma(self, x: torch.Tensor) ‑> torch.Tensor
-
Compute the LinUCB uncertainty term sigma = sqrt(x^T * A^{-1} * x) for each input row (a column of ones is appended for the intercept first).
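An illustrative property check (assumed toy data): sigma shrinks for a context as observations of it accumulate, which is what drives exploration in LinUCB. Note that _inv_A starts as zeros, so calculate_coefs() is called once here to populate it from the prior A = l2_reg_lambda * I.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

model = LinearRegression(feature_dim=2)
model.calculate_coefs()                         # populate _inv_A from the prior A

context = torch.tensor([[0.5, 1.0]])
sigma_prior = model.calculate_sigma(context)    # high uncertainty under the prior

model.learn_batch(context.repeat(10, 1), torch.ones(10), weight=None)
sigma_post = model.calculate_sigma(context)     # smaller after observing this context

print(sigma_prior, sigma_post)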
def forward(self, x: torch.Tensor) ‑> torch.Tensor
-
x can be a single vector or a batch. If x is a batch, it has shape (batch_size, ...); the return value has shape (batch_size,).
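A small sketch relating predictions to the stored coefficients (assumed toy values): forward prepends the intercept column, so the prediction equals coefs[0] + coefs[1:] . x, and both a single vector and a batch are accepted.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

model = LinearRegression(feature_dim=2)
model.learn_batch(torch.randn(32, 2), torch.randn(32), weight=None)

batch = torch.randn(4, 2)
preds = model(batch)                                  # shape (4,)
manual = model.coefs[0] + batch @ model.coefs[1:]     # intercept + weights . x
assert torch.allclose(preds, manual, atol=1e-5)

single = model(torch.tensor([0.1, -0.2]))             # scalar prediction for one vector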
def learn_batch(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor]) ‑> None
-
Weighted update of the sufficient statistics: A <- A + sum_i w_i * x_i * x_i^T, b <- b + sum_i w_i * y_i * x_i (w defaults to all ones when weight is None).
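A hedged sketch of the accumulation behavior on assumed random data: calling learn_batch on two halves of a dataset yields the same A, b, and coefs as a single call on the full dataset, since the update only adds per-sample sums.

import torch
from pearl.utils.functional_utils.learning.linear_regression import LinearRegression

x = torch.randn(20, 3)
y = torch.randn(20)

incremental = LinearRegression(feature_dim=3)
incremental.learn_batch(x[:10], y[:10], weight=None)
incremental.learn_batch(x[10:], y[10:], weight=None)

single_shot = LinearRegression(feature_dim=3)
single_shot.learn_batch(x, y, weight=None)

assert torch.allclose(incremental.A, single_shot.A, atol=1e-4)
assert torch.allclose(incremental.coefs, single_shot.coefs, atol=1e-3)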