Module pearl.utils.device
Source code
import torch
import torch.distributed as dist
from pearl.utils.functional_utils.python_utils import value_of_first_item
class DeviceNotFoundInModuleError(ValueError):
pass
def get_device(module: torch.nn.Module) -> torch.device:
"""
Get the device that a module is on.
This is achieved by looking for non-empty parameters in the module and returning the
device of the first parameter found.
    If the module has no parameters, we recurse into its first sub-module,
    and so on down the tree, until a parameter is found.
    If the module has neither parameters nor sub-modules,
    a DeviceNotFoundInModuleError is raised.
"""
if (
hasattr(module, "_parameters")
and (first_parameter := value_of_first_item(module._parameters)) is not None
):
return first_parameter.device
elif (first_sub_module := value_of_first_item(module._modules)) is not None:
try:
return get_device(first_sub_module)
except DeviceNotFoundInModuleError:
raise DeviceNotFoundInModuleError(
f"Could not find a device for module {module} because it "
"has no parameters and could not find device in its first sub-module"
)
else:
raise DeviceNotFoundInModuleError(
f"Cannot determine the device for module {module}"
"because it has neither parameters nor sub-modules"
)
def get_pearl_device(device_id: int = -1) -> torch.device:
    """
    Returns the device on which Pearl should run.
    If `device_id` is not -1, the corresponding CUDA device is returned.
    Otherwise, if torch.distributed is initialized, the process rank is used
    as the CUDA device index; if not, device 0 is used.
    CPU is returned when CUDA is not available.
    """
    if device_id != -1:
        return torch.device("cuda:" + str(device_id))
    try:
        # Supports PyTorch distributed runs by using the process rank as the
        # CUDA device index; this does not affect non-distributed runs.
        local_rank = dist.get_rank()
    except Exception:
        # torch.distributed is not initialized; default to device 0.
        local_rank = 0
return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
def is_distribution_enabled() -> bool:
    # Check availability first: builds without distributed support may not
    # define dist.is_initialized at all.
    return dist.is_available() and dist.is_initialized()
def get_default_device() -> torch.device:
"""
Returns the torch default device, that is,
the device on which factory methods without a `device`
specification place their tensors.
"""
return torch.tensor(0).device
Functions
def get_default_device() -> torch.device
    Returns the torch default device, that is, the device on which factory
    methods without a `device` specification place their tensors.
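A minimal illustrative sketch; the default device is typically the CPU unless it has been changed (e.g. via torch.set_default_device in recent PyTorch versions).

import torch
from pearl.utils.device import get_default_device

# Reports where factory calls without an explicit `device=` argument place
# their tensors.
print(get_default_device())  # typically cpu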
def get_device(module: torch.nn.modules.module.Module) -> torch.device
    Get the device that a module is on. This is achieved by looking for non-empty
    parameters in the module and returning the device of the first parameter found.
    If the module has no parameters, we recurse into its first sub-module, and so on
    down the tree, until a parameter is found. If the module has neither parameters
    nor sub-modules, a DeviceNotFoundInModuleError is raised.
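A brief illustrative example of recovering a module's device; the model below is arbitrary and not part of the module.

import torch.nn as nn
from pearl.utils.device import get_device

# The Sequential container has no parameters of its own, so get_device
# recurses into its first sub-module and returns the Linear weight's device.
model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
print(get_device(model))  # cpu, or the device the model was moved to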
def get_pearl_device(device_id: int = -1) -> torch.device
    Returns the device on which Pearl should run: the given CUDA device when
    `device_id` is not -1; otherwise the distributed process rank (or 0) is used
    as the CUDA device index, falling back to CPU when CUDA is unavailable.
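A minimal sketch of typical use; the name policy_input is illustrative only.

import torch
from pearl.utils.device import get_pearl_device

# Explicit id: always maps to that CUDA device.
print(get_pearl_device(device_id=0))  # cuda:0

# No id: distributed rank if initialized, otherwise cuda:0 when CUDA is
# available, or cpu when it is not.
device = get_pearl_device()
policy_input = torch.zeros(4, device=device)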
def is_distribution_enabled() -> bool
    Returns True if torch.distributed is available and a process group has been
    initialized, False otherwise.
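An illustrative guard around a collective operation; in a single-process run the reduction is simply skipped.

import torch
import torch.distributed as dist
from pearl.utils.device import is_distribution_enabled

# Aggregate a metric across processes only when a process group is running.
loss = torch.tensor(1.0)
if is_distribution_enabled():
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)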
Classes
class DeviceNotFoundInModuleError (*args, **kwargs)
    Raised by get_device when a module has neither parameters nor sub-modules,
    so no device can be determined for it.
Ancestors
- builtins.ValueError
- builtins.Exception
- builtins.BaseException
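An illustrative case that triggers this error: a parameter-free module with no sub-modules.

import torch.nn as nn
from pearl.utils.device import get_device, DeviceNotFoundInModuleError

# nn.ReLU() holds no parameters and no sub-modules, so get_device has
# nothing to inspect and raises DeviceNotFoundInModuleError.
try:
    get_device(nn.ReLU())
except DeviceNotFoundInModuleError as error:
    print(error)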