Module pearl.utils.functional_utils.learning.action_utils

Expand source code
from typing import Optional

import torch
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import Tensor


def argmax_random_tie_breaks(
    scores: Tensor, mask: Optional[Tensor] = None
) -> torch.Tensor:
    """
    Given a 2D tensor of scores, return the indices of the max score for each row.
    If there are ties inside a row, uniformly randomize among the ties.
    IMPORTANT IMPLEMENTATION DETAILS:
        1. Randomization is implemented consistently across all rows, e.g., if the same columns
            are tied in two different rows, the same index will be returned for both rows.

    Args:
        scores: A 2D tensor of scores
        mask [Optional]: A 2D score presence mask.
                         If missing, all scores are assumed to be unmasked.
    """
    # This function only works for 2D tensors
    assert scores.ndim == 2

    # Permute the columns
    num_cols = scores.size(1)
    random_col_indices = torch.randperm(num_cols)
    permuted_scores = torch.index_select(scores, 1, random_col_indices)
    if mask is not None:
        permuted_mask = torch.index_select(mask, 1, random_col_indices)
        permuted_scores = torch.masked.as_masked_tensor(
            permuted_scores, permuted_mask.bool()
        )

    # Find the indices of the maximum elements in the random permutation
    max_indices_in_permuted_data = torch.argmax(permuted_scores, dim=1)

    if mask is not None:
        # pyre-fixme[16]: `Tensor` has no attribute `get_data`.
        max_indices_in_permuted_data = max_indices_in_permuted_data.get_data().long()

    # Use the random permutation to get the original indices of the maximum elements
    argmax_indices = random_col_indices[max_indices_in_permuted_data]

    return argmax_indices


def get_model_actions(
    scores: Tensor,
    mask: Optional[Tensor] = None,
    randomize_ties: bool = False,
) -> torch.Tensor:
    """
    Given a 2D tensor of scores, get the indices of the chosen actions.
    Chosen actions are the score argmax (within each row), subject to an optional mask.
    If `randomize_ties` is True, ties among maximal values are broken uniformly at random;
        this incurs additional computational cost compared to not randomizing (which uses the first index).

    Args:
        scores: A 2D tensor of scores
        mask [Optional]: A 2D score presence mask.
                         If missing, all scores are assumed to be unmasked.

    Returns:
        1D tensor of size (batch_size,)
    """
    if randomize_ties:
        model_actions = argmax_random_tie_breaks(scores, mask)
    else:
        if mask is None:
            # vanilla argmax - no masking or randomization
            model_actions = torch.argmax(scores, dim=1)
        else:
            # mask out non-present arms
            scores_masked = torch.masked.as_masked_tensor(scores, mask.bool())
            model_actions = (
                # pyre-fixme[16]: `Tensor` has no attribute `get_data`.
                torch.argmax(scores_masked, dim=1).get_data()
            )
    return model_actions


def concatenate_actions_to_state(
    subjective_state: Tensor,
    action_space: DiscreteActionSpace,
    state_features_only: bool = False,
) -> Tensor:
    """A helper function for concatenating all actions from a `DiscreteActionSpace`
    to a state or batch of states. The actions must be Tensors.

    Args:
        subjective_state: A Tensor of shape (batch_size, state_dim) or (state_dim).
        action_space: A `DiscreteActionSpace` object where each action is a Tensor.
        state_features_only: If True, only expand the state dimension without
            concatenating the actions.
    Returns:
        A Tensor of shape (batch_size, action_count, state_dim + action_dim).
    """
    state_dim = subjective_state.shape[-1]
    # Reshape to (batch_size, state_dim)
    subjective_state = subjective_state.view(-1, state_dim)
    batch_size = subjective_state.shape[0]

    action_dim = action_space.action_dim
    action_count = action_space.n

    # Expand to (batch_size, action_count, state_dim) and return if `state_features_only`
    expanded_state = subjective_state.unsqueeze(1).repeat(1, action_count, 1)
    if state_features_only:
        return expanded_state

    # Stack actions and expand to (batch_size, action_count, action_dim)
    actions = torch.stack(action_space.actions).to(subjective_state.device)
    expanded_action = actions.unsqueeze(0).repeat(batch_size, 1, 1)

    # (batch_size, action_count, state_dim + action_dim)
    new_feature = torch.cat([expanded_state, expanded_action], dim=2)
    torch._assert(
        new_feature.shape == (batch_size, action_count, state_dim + action_dim),
        "The shape of the concatenated feature is wrong. Expected "
        f"{(batch_size, action_count, state_dim + action_dim)}, got {new_feature.shape}",
    )
    return new_feature.to(subjective_state.device)

Functions

def argmax_random_tie_breaks(scores: torch.Tensor, mask: Optional[torch.Tensor] = None) ‑> torch.Tensor

Given a 2D tensor of scores, return the indices of the max score for each row. If there are ties inside a row, uniformly randomize among the ties. IMPORTANT IMPLEMENTATION DETAILS: 1. Randomization is implemented consistently across all rows, e.g., if the same columns are tied in two different rows, the same index will be returned for both rows.

Args

scores
A 2D tensor of scores.
mask [Optional]
A 2D score presence mask. If missing, all scores are assumed to be unmasked.

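Example: a minimal usage sketch (not part of the source; the score and mask values are illustrative, and the masked branch relies on the prototype torch.masked API, which may emit a warning).

import torch

from pearl.utils.functional_utils.learning.action_utils import argmax_random_tie_breaks

# Columns 0 and 2 are tied in both rows. Because the column permutation is
# shared across rows, both rows resolve the tie to the same column.
scores = torch.tensor([[1.0, 0.0, 1.0],
                       [1.0, 0.5, 1.0]])
print(argmax_random_tie_breaks(scores))  # tensor([0, 0]) or tensor([2, 2]), depending on the permutation

# With a presence mask, column 0 of the first row is treated as absent,
# so the first row can only resolve to column 2.
mask = torch.tensor([[0, 1, 1],
                     [1, 1, 1]])
print(argmax_random_tie_breaks(scores, mask))
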
def concatenate_actions_to_state(subjective_state: torch.Tensor, action_space: DiscreteActionSpace, state_features_only: bool = False) ‑> torch.Tensor

A helper function for concatenating all actions from a DiscreteActionSpace to a state or batch of states. The actions must be Tensors.

Args

subjective_state
A Tensor of shape (batch_size, state_dim) or (state_dim).
action_space
A DiscreteActionSpace object where each action is a Tensor.
state_features_only
If True, only expand the state dimension without concatenating the actions.

Returns

A Tensor of shape (batch_size, action_count, state_dim + action_dim).

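Example: a minimal usage sketch. The dimensions are illustrative, and it assumes DiscreteActionSpace can be built from a list of action Tensors; only the actions, action_dim and n attributes used in the source above are relied upon.

import torch

from pearl.utils.functional_utils.learning.action_utils import concatenate_actions_to_state
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

# Hypothetical space with 2 actions, each a 3-dimensional feature vector.
action_space = DiscreteActionSpace(
    actions=[torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0])]
)

states = torch.rand(4, 5)  # batch of 4 states with state_dim = 5

features = concatenate_actions_to_state(states, action_space)
print(features.shape)  # torch.Size([4, 2, 8]): (batch_size, action_count, state_dim + action_dim)

# With state_features_only=True, only the state is expanded per action.
state_only = concatenate_actions_to_state(states, action_space, state_features_only=True)
print(state_only.shape)  # torch.Size([4, 2, 5])
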
def get_model_actions(scores: torch.Tensor, mask: Optional[torch.Tensor] = None, randomize_ties: bool = False) ‑> torch.Tensor

Given a 2D tensor of scores, get the indices of the chosen actions. Chosen actions are the score argmax (within each row), subject to an optional mask. If randomize_ties=True, ties among maximal values are broken uniformly at random; this incurs additional computational cost compared to not randomizing (which uses the first index).

Args

scores
A 2D tensor of scores.
mask [Optional]
A 2D score presence mask. If missing, all scores are assumed to be unmasked.

Returns

1D tensor of size (batch_size,)

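Example: a minimal usage sketch (the score and mask values are illustrative; the masked branch relies on the prototype torch.masked API and may emit a warning).

import torch

from pearl.utils.functional_utils.learning.action_utils import get_model_actions

scores = torch.tensor([[0.1, 0.9, 0.3],
                       [0.7, 0.2, 0.3]])

# Plain per-row argmax.
print(get_model_actions(scores))  # tensor([1, 0])

# Masking out column 1 of the first row moves its argmax to column 2.
mask = torch.tensor([[1, 0, 1],
                     [1, 1, 1]])
print(get_model_actions(scores, mask))  # tensor([2, 0])

# With randomize_ties=True, ties among maxima are broken uniformly at random,
# using the same tie-broken column for every row (see argmax_random_tie_breaks).
tied = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 1.0, 0.5]])
print(get_model_actions(tied, randomize_ties=True))  # tensor([0, 0]) or tensor([1, 1])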