Module pearl.utils.functional_utils.learning.extend_state_feature

Expand source code
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import torch


def extend_state_feature_by_available_action_space(
    state_batch: torch.Tensor,
    curr_available_actions_batch: torch.Tensor,
) -> torch.Tensor:
    """
    Repeat each state's feature vector once per available action.

    Input dim:
    state_batch: batch_size x state_dim
    curr_available_actions_batch: batch_size x available_action_space_size x action_dim
        (only its second-to-last dimension size is read; the action values
        themselves are never used)

    Output dim:
    state_available_actions_batch: batch_size x available_action_space_size x state_dim

    Example (batch_size=1, state_dim=4, available_action_space_size=2):
    state_batch: [1 x 4] --> returned tensor: [1 x 2 x 4]
    [[1,2,3,4]] --> [[[1,2,3,4],
                      [1,2,3,4]]]
    """
    # Number of available actions = size of the second-to-last axis.
    num_available_actions = curr_available_actions_batch.shape[-2]

    # Insert a new axis at position 1:
    # (batch_size x state_dim) -> (batch_size x 1 x state_dim).
    states_with_action_axis = state_batch.unsqueeze(1)

    # Repeat along the new axis so every available action gets its own copy
    # of the state features:
    # (batch_size x available_action_space_size x state_dim).
    return states_with_action_axis.repeat_interleave(num_available_actions, dim=1)

Functions

def extend_state_feature_by_available_action_space(state_batch: torch.Tensor, curr_available_actions_batch: torch.Tensor) ‑> torch.Tensor

Helper function that repeats each state's feature vector once per available action.

Input dim: state_batch: batch_size x state_dim; curr_available_actions_batch: batch_size x available_action_space_size x action_dim

Output dim: state_available_actions_batch: batch_size x available_action_space_size x state_dim

Expand source code
def extend_state_feature_by_available_action_space(
    state_batch: torch.Tensor,
    curr_available_actions_batch: torch.Tensor,
) -> torch.Tensor:
    """
    Repeat each state's feature vector once per available action.

    Input dim:
    state_batch: batch_size x state_dim
    curr_available_actions_batch: batch_size x available_action_space_size x action_dim
        (only its second-to-last dimension size is read; the action values
        themselves are never used)

    Output dim:
    state_available_actions_batch: batch_size x available_action_space_size x state_dim

    Example (batch_size=1, state_dim=4, available_action_space_size=2):
    state_batch: [1 x 4] --> returned tensor: [1 x 2 x 4]
    [[1,2,3,4]] --> [[[1,2,3,4],
                      [1,2,3,4]]]
    """
    # Number of available actions = size of the second-to-last axis.
    num_available_actions = curr_available_actions_batch.shape[-2]

    # Insert a new axis at position 1:
    # (batch_size x state_dim) -> (batch_size x 1 x state_dim).
    states_with_action_axis = state_batch.unsqueeze(1)

    # Repeat along the new axis so every available action gets its own copy
    # of the state features:
    # (batch_size x available_action_space_size x state_dim).
    return states_with_action_axis.repeat_interleave(num_available_actions, dim=1)