Skip to content

Module wtracker.neural.mlp

View Source
from torch import Tensor, nn

from typing import Union, Sequence

from collections import defaultdict

from wtracker.neural.config import IOConfig

# Registry of activation classes that MLP layers can be configured with by name.
# Both the string "none" and the value None map to nn.Identity (no activation).
ACTIVATIONS = dict(
    relu=nn.ReLU,
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    softmax=nn.Softmax,
    logsoftmax=nn.LogSoftmax,
    lrelu=nn.LeakyReLU,
    none=nn.Identity,
)
# None cannot be spelled as a keyword argument above, so register it separately.
ACTIVATIONS[None] = nn.Identity

# Constructor keyword arguments for each activation class, used as
#     ACTIVATIONS[name](**ACTIVATION_DEFAULT_KWARGS[name])
# Names without an explicit entry fall back to an empty dict (no extra arguments).
# The (log-)softmax variants normalize over dim 1, the feature axis of the
# flattened (N, D) inputs the MLP layers operate on.
ACTIVATION_DEFAULT_KWARGS = defaultdict(
    dict,
    softmax=dict(dim=1),
    logsoftmax=dict(dim=1),
)

class WormPredictor(nn.Module):
    """
    Marker wrapper for neural networks that predict worm behavior.

    A model assembled from layers or blocks is wrapped in this class so that it
    can be told apart from models that do not predict worm behavior (for
    example, the very layers/blocks it was built from). The wrapper also keeps
    the IOConfig that describes the model's input and output shapes and which
    frames it expects as input and produces as output.

    Attributes:
        model: The wrapped network that predicts worm behavior.
        io_config: The IOConfig object associated with the model.
    """

    def __init__(self, model: nn.Module, io_config: IOConfig):
        super().__init__()
        self.io_config: IOConfig = io_config
        self.model: nn.Module = model

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model on ``x`` and return its output unchanged."""
        prediction = self.model(x)
        return prediction

class MLPLayer(nn.Module):
    """
    A single-layer perceptron: a linear layer, optionally followed by batch
    normalization, followed by an activation.

    Args:
        in_dim: Number of input features. Inputs are flattened to
            (N, in_dim) in ``forward``.
        out_dim: Number of output features.
        nonlin: The activation. Either a key of ``ACTIVATIONS`` (``"none"`` or
            ``None`` mean no activation) or an ``nn.Module`` instance, which is
            used as-is.
        batch_norm: Whether to insert ``nn.BatchNorm1d`` between the linear
            layer and the activation. Ignored when ``nonlin`` is ``"none"`` or
            ``None``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        nonlin: Union[str, nn.Module, None],
        batch_norm: bool = True,
    ) -> None:
        super().__init__()

        layers = [nn.Linear(in_dim, out_dim)]
        # Batch norm is only inserted when a real activation follows the linear layer.
        if batch_norm and nonlin not in ["none", None]:
            layers.append(nn.BatchNorm1d(out_dim))
        layers.append(self._make_activation(nonlin))

        self.mlp_layer = nn.Sequential(*layers)

    def _make_activation(self, act: Union[str, nn.Module, None]) -> nn.Module:
        """
        Return an activation module for ``act``.

        Strings and ``None`` are looked up in ``ACTIVATIONS`` and instantiated
        with ``ACTIVATION_DEFAULT_KWARGS``; module instances are returned unchanged.
        """
        # None must go through the lookup too: returning it as-is would place
        # None inside nn.Sequential, which then fails when called in forward.
        if act is None or isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor whose first dimension is the batch size N; all
                remaining dimensions are flattened into D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Call the submodule (not .forward) so hooks registered on it run.
        return self.mlp_layer(x.reshape(x.size(0), -1))

class MlpBlock(nn.Module):
    """
    A general-purpose MLP built as a stack of ``MLPLayer`` instances.

    Args:
        in_dim: Input dimension.
        dims: Hidden dimensions, including the output dimension. Must be non-empty.
        nonlins: Non-linearities to apply after each one of the hidden
            dimensions. Can be either a sequence of strings which are keys in
            the ACTIVATIONS dict, or instances of nn.Module (e.g. an instance
            of nn.ReLU()). Length must match ``dims``.
        batch_norm: Passed to every ``MLPLayer``.

    Raises:
        ValueError: If ``len(nonlins) != len(dims)`` or ``dims`` is empty.
    """

    def __init__(
        self,
        in_dim: int,
        dims: Sequence[int],
        nonlins: Sequence[Union[str, nn.Module]],
        batch_norm: bool = True,
    ):
        # Initialise nn.Module before assigning any attributes (standard practice).
        super().__init__()

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(nonlins) != len(dims):
            raise ValueError(
                f"nonlins and dims must have the same length, got {len(nonlins)} and {len(dims)}"
            )
        if len(dims) == 0:
            raise ValueError("dims must contain at least one dimension")

        self.in_dim = in_dim
        self.out_dim = dims[-1]
        self.dims = dims
        self.nonlins = nonlins

        layers = []
        for out_dim, nonlin in zip(dims, nonlins):
            layers.append(MLPLayer(in_dim, out_dim, nonlin, batch_norm))
            in_dim = out_dim  # each layer's output feeds the next layer
        self.sequence = nn.Sequential(*layers)

    def _make_activation(self, act: Union[str, nn.Module, None]) -> nn.Module:
        """
        Return an activation module for ``act``.

        Strings and ``None`` are looked up in ``ACTIVATIONS``; module instances
        are returned unchanged.
        """
        # None is a valid ACTIVATIONS key; returning it as-is would yield a non-module.
        if act is None or isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor whose first dimension is the batch size N; all
                remaining dimensions are flattened into D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Call the submodule (not .forward) so hooks registered on it run.
        return self.sequence(x.reshape(x.size(0), -1))

class RMLP(nn.Module):
    """
    Residual MLP: an optional input projection, ``n_blocks`` residual
    ``MlpBlock``s (each computing ``x = x + block(x)``), and a linear output layer.

    Because of the residual sum, every block must map ``block_in_dim`` features
    back to ``block_in_dim`` features, i.e. ``block_dims[-1] == block_in_dim``.

    Args:
        block_in_dim: Feature dimension entering (and leaving) each residual block.
        block_dims: Hidden dimensions of each block, including its output dimension.
        block_nonlins: Non-linearities for each block layer (length must match
            ``block_dims``). ``block_nonlins[0]`` is also used for the input layer.
        n_blocks: Number of residual blocks.
        out_dim: Number of output features.
        in_dim: If not None, an ``MLPLayer`` projecting ``in_dim`` ->
            ``block_in_dim`` is prepended. Otherwise the (flattened) input must
            already have ``block_in_dim`` features.
        batch_norm: Passed to every MLP layer.

    Raises:
        ValueError: If ``block_dims`` is empty or ``block_dims[-1] != block_in_dim``.
    """

    def __init__(
        self,
        block_in_dim: int,
        block_dims: Sequence[int],
        block_nonlins: Sequence[Union[str, nn.Module]],
        n_blocks: int,
        out_dim: int,
        in_dim: Union[int, None] = None,  # if in_dim is an int, then a first layer will be made
        batch_norm: bool = True,
    ) -> None:
        super().__init__()

        # Fail early instead of with a shape error in the residual add / output layer.
        if len(block_dims) == 0:
            raise ValueError("block_dims must contain at least one dimension")
        if block_dims[-1] != block_in_dim:
            raise ValueError(
                f"block_dims[-1] ({block_dims[-1]}) must equal block_in_dim ({block_in_dim}) "
                "for the residual connections"
            )

        # Create first layer if in_dim is not None
        self.input = nn.Identity()
        if in_dim is not None:
            self.input = MLPLayer(in_dim, block_in_dim, block_nonlins[0], batch_norm)

        # Create blocks
        self.blocks = nn.ModuleList(
            [MlpBlock(block_in_dim, block_dims, block_nonlins, batch_norm) for _ in range(n_blocks)]
        )

        # Create output layer
        self.output = nn.Linear(block_dims[-1], out_dim)

    def _make_activation(self, act: Union[str, nn.Module, None]) -> nn.Module:
        """
        Return an activation module for ``act``.

        Strings and ``None`` are looked up in ``ACTIVATIONS``; module instances
        are returned unchanged.
        """
        # None is a valid ACTIVATIONS key; returning it as-is would yield a non-module.
        if act is None or isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor whose first dimension is the batch size N; all
                remaining dimensions are flattened into D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Flatten up front so the residual sums below operate on (N, D) tensors
        # even when the input layer is nn.Identity (which does not flatten).
        x = self.input(x.reshape(x.size(0), -1))
        for block in self.blocks:
            x = x + block(x)
        return self.output(x)

Variables

ACTIVATIONS
ACTIVATION_DEFAULT_KWARGS

Classes

MLPLayer

class MLPLayer(
    in_dim: int,
    out_dim: int,
    nonlin: Union[str, torch.nn.modules.module.Module],
    batch_norm: bool = True
)

A single-layer perceptron that can also hold batch-norm and activation layers.

View Source
class MLPLayer(nn.Module):
    """
    A single-layer perceptron: a linear layer, optionally followed by batch
    normalization, followed by an activation.

    Args:
        in_dim: Number of input features. Inputs are flattened to
            (N, in_dim) in ``forward``.
        out_dim: Number of output features.
        nonlin: The activation. Either a key of ``ACTIVATIONS`` (``"none"`` or
            ``None`` mean no activation) or an ``nn.Module`` instance, which is
            used as-is.
        batch_norm: Whether to insert ``nn.BatchNorm1d`` between the linear
            layer and the activation. Ignored when ``nonlin`` is ``"none"`` or
            ``None``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        nonlin: Union[str, nn.Module, None],
        batch_norm: bool = True,
    ) -> None:
        super().__init__()

        layers = [nn.Linear(in_dim, out_dim)]
        # Batch norm is only inserted when a real activation follows the linear layer.
        if batch_norm and nonlin not in ["none", None]:
            layers.append(nn.BatchNorm1d(out_dim))
        layers.append(self._make_activation(nonlin))

        self.mlp_layer = nn.Sequential(*layers)

    def _make_activation(self, act: Union[str, nn.Module, None]) -> nn.Module:
        """
        Return an activation module for ``act``.

        Strings and ``None`` are looked up in ``ACTIVATIONS`` and instantiated
        with ``ACTIVATION_DEFAULT_KWARGS``; module instances are returned unchanged.
        """
        # None must go through the lookup too: returning it as-is would place
        # None inside nn.Sequential, which then fails when called in forward.
        if act is None or isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor whose first dimension is the batch size N; all
                remaining dimensions are flattened into D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Call the submodule (not .forward) so hooks registered on it run.
        return self.mlp_layer(x.reshape(x.size(0), -1))

Ancestors (in MRO)

  • torch.nn.modules.module.Module

Class variables

T_destination
call_super_init
dump_patches

Methods

add_module

def add_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters:

Name Type Description Default
name str name of the child module. The child module can be
accessed from this module using the given name
None
module Module child module to be added to the module. None
View Source
    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Add a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
                ``None`` is also accepted.

        Raises:
            TypeError: if ``module`` is neither a Module nor None, or ``name``
                is not a string.
            KeyError: if ``name`` is empty, contains ".", or clashes with an
                existing non-module attribute.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        elif not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
        # Refuse to shadow an existing attribute that is not already a registered module.
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        elif '.' in name:
            raise KeyError(f"module name can't contain \".\", got: {name}")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        # Global registration hooks may substitute the module: any non-None
        # return value replaces it before it is stored.
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

apply

def apply(
    self: ~T,
    fn: Callable[[ForwardRef('Module')], NoneType]
) -> ~T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).

Parameters:

Name Type Description Default
fn Callable[[Module], None] function to be applied to each submodule None

Returns:

Type Description
Module self
View Source
    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

        Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
        """
        # Children are visited before self (post-order traversal).
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

bfloat16

def bfloat16(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def bfloat16(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors are passed through unchanged.
        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

buffers

def buffers(
    self,
    recurse: bool = True
) -> Iterator[torch.Tensor]

Return an iterator over module buffers.

Parameters:

Name Type Description Default
recurse bool if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
None

Yields:

Type Description
torch.Tensor module buffer
View Source
    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Return an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for buf in model.buffers():
            >>>     print(type(buf), buf.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
        """
        # Thin view over named_buffers that drops the names.
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

children

def children(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over immediate children modules.

Yields:

Type Description
Module a child module
View Source
    def children(self) -> Iterator['Module']:
        r"""Return an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        # Thin view over named_children that drops the names.
        for name, module in self.named_children():
            yield module

compile

def compile(
    self,
    *args,
    **kwargs
)

Compile this Module's forward using :func:torch.compile.

This Module's __call__ method is compiled and all arguments are passed as-is to :func:torch.compile.

See :func:torch.compile for details on the arguments for this function.

View Source
    def compile(self, *args, **kwargs):
        """
        Compile this Module's forward using :func:`torch.compile`.

        This Module's `__call__` method is compiled and all arguments are passed as-is
        to :func:`torch.compile`.

        See :func:`torch.compile` for details on the arguments for this function.
        """
        # Only _compiled_call_impl is set here; forward itself is not replaced.
        # NOTE(review): presumably __call__ prefers _compiled_call_impl when set — not shown here.
        self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)

cpu

def cpu(
    self: ~T
) -> ~T

Move all model parameters and buffers to the CPU.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def cpu(self: T) -> T:
        r"""Move all model parameters and buffers to the CPU.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cpu())

cuda

def cuda(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing optimizer if the module will
        live on GPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cuda(device))

double

def double(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to double datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def double(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``double`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors are passed through unchanged.
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

eval

def eval(
    self: ~T
) -> ~T

Set the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

This is equivalent to :meth:self.train(False) <torch.nn.Module.train>.

See :ref:locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Type Description
Module self
View Source
    def eval(self: T) -> T:
        r"""Set the module in evaluation mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.

        See :ref:`locally-disable-grad-doc` for a comparison between
        `.eval()` and several similar mechanisms that may be confused with it.

        Returns:
            Module: self
        """
        return self.train(False)

extra_repr

def extra_repr(
    self
) -> str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source
    def extra_repr(self) -> str:
        r"""Set the extra representation of the module.

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        # The base implementation contributes nothing; subclasses override it.
        return ''

float

def float(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to float datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def float(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``float`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors are passed through unchanged.
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

forward

def forward(
    self,
    x: torch.Tensor
) -> torch.Tensor

Parameters:

Name Type Description Default
x None An input tensor, of shape (N, D) containing N samples with D features. None

Returns:

Type Description
None An output tensor of shape (N, D_out) where D_out is the output dim.
View Source
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor whose first dimension is the batch size N; all
                remaining dimensions are flattened into D features.
        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Call the submodule (not .forward) so hooks registered on it run.
        return self.mlp_layer(x.reshape(x.size(0), -1))

get_buffer

def get_buffer(
    self,
    target: str
) -> 'Tensor'

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the buffer
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.Tensor The buffer referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not a
buffer
View Source
    def get_buffer(self, target: str) -> "Tensor":
        """Return the buffer given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the buffer
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not a
                buffer
        """
        # Split "a.b.buf" into the owning module path "a.b" and the attribute "buf";
        # for a bare name the module path is "" and get_submodule returns self.
        module_path, _, buffer_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, buffer_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + buffer_name + "`")
        buffer: torch.Tensor = getattr(mod, buffer_name)
        # The attribute exists, but it must also be registered as a buffer.
        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")
        return buffer

get_extra_state

def get_extra_state(
    self
) -> Any

Return any extra state to include in the module's state_dict.

Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Type Description
object Any extra state to store in the module's state_dict
View Source
    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Implement this and a corresponding :func:`set_extra_state` for your module
        if you need to store extra state. This function is called when building the
        module's `state_dict()`.

        Note that extra state should be picklable to ensure working serialization
        of the state_dict. We only provide backwards compatibility guarantees
        for serializing Tensors; other objects may break backwards compatibility if
        their serialized pickled form changes.

        Returns:
            object: Any extra state to store in the module's state_dict
        """
        # The base implementation is never meant to be reached; per the message
        # below, hitting it indicates a bug.
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")

get_parameter

def get_parameter(
    self,
    target: str
) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the Parameter
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Parameter The Parameter referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Parameter
View Source
    def get_parameter(self, target: str) -> "Parameter":
        """Return the parameter given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the Parameter
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Parameter``
        """
        # Split "a.b.weight" into the owning module path "a.b" and the attribute "weight";
        # for a bare name the module path is "" and get_submodule returns self.
        module_path, _, param_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        if not hasattr(mod, param_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + param_name + "`")
        param: torch.nn.Parameter = getattr(mod, param_name)
        # The attribute exists, but it must actually be an nn.Parameter.
        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an "
                                 "nn.Parameter")
        return param

get_submodule

def get_submodule(
    self,
    target: str
) -> 'Module'

Return the submodule given by target if it exists, otherwise throw an error.

For example, let's say you have an nn.Module A that looks like this:

.. code-block:: text

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Module The submodule referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Module
View Source
    def get_submodule(self, target: str) -> "Module":
        """Return the submodule given by ``target`` if it exists, otherwise throw an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.

        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        # The empty path denotes the module itself.
        if target == "":
            return self
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
        # Walk one path component at a time, checking that every step is a Module.
        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")
            mod = getattr(mod, item)
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")
        return mod

half

def half(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to half datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def half(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``half`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors are passed through unchanged.
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

ipu

def ipu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the IPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing optimizer if the module will
        live on IPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.ipu(device))

load_state_dict

def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False
)

Copy parameters and buffers from `state_dict` into this module and its descendants.

If `strict` is True, then the keys of `state_dict` must exactly match the keys returned by this module's `torch.nn.Module.state_dict` function.

Warning: if `assign` is True, the optimizer must be created after the call to `load_state_dict` unless `torch.__future__.get_swap_module_params_on_conversion()` is True.

Parameters:

Name Type Description Default
state_dict dict a dict containing parameters and
persistent buffers.
required
strict bool whether to strictly enforce that the keys
in `state_dict` match the keys returned by this module's
`torch.nn.Module.state_dict` function. Default: True
True
assign bool When False, the properties of the tensors
in the current module are preserved, while when True, the
properties of the tensors in the state dict are preserved. The only
exception is the requires_grad field of `torch.nn.Parameter` objects,
for which the value from the module is preserved.
Default: False
False

Returns:

Type Description
NamedTuple A NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
View Source
    def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True, assign: bool = False):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

        If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict` unless
            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): When ``False``, the properties of the tensors
                in the current module are preserved while when ``True``, the
                properties of the Tensors in the state dict are preserved. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
                for which the value from the module is preserved.
                Default: ``False``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            TypeError: if ``state_dict`` is not a ``Mapping``.
            RuntimeError: if any error messages were collected while loading,
                including missing/unexpected keys when ``strict`` is ``True``.
            AssertionError: if a hook registered with
                ``register_load_state_dict_post_hook`` returns a value other
                than ``None``.

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # Shallow-copy state_dict so _load_from_state_dict can modify it without
        # touching the caller's mapping. The copy does not carry the
        # `_metadata` attribute, so it is re-attached by hand.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            # Per-module metadata is keyed by the prefix without its trailing '.'.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    # Hand each child only the entries that live under its own prefix.
                    child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821

            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                # Each literal ends with a space so the joined message reads correctly.
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

        load(self, state_dict)
        # `load` refers to itself through its closure; deleting the name breaks
        # that self-referential cycle.
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

modules

def modules(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over all modules in the network.

Yields:

Type Description
Module a module in the network
View Source
    def modules(self) -> Iterator['Module']:
        r"""Iterate over every module in the network, this one included.

        This is a thin view over :meth:`named_modules` that drops the names.

        Yields:
            Module: a module in the network

        Note:
            A module object registered under several names is produced a
            single time. In the example below, ``l`` shows up only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            ...     print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)
        """
        for entry in self.named_modules():
            # entry is a (name, module) pair; only the module is exposed here.
            yield entry[1]

named_buffers

def named_buffers(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all buffer names. ''
recurse bool if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
True
remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True

Yields:

Type Description
Tuple[str, torch.Tensor] Tuple containing the name and buffer
View Source
    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Iterate over module buffers as ``(name, buffer)`` pairs.

        Args:
            prefix (str): string prepended to every yielded buffer name.
            recurse (bool, optional): when True, buffers of all submodules are
                included as well; when False, only buffers registered directly
                on this module are yielded. Defaults to True.
            remove_duplicate (bool, optional): whether a buffer reachable
                through several paths is yielded only once. Defaults to True.

        Yields:
            (str, torch.Tensor): the buffer's name and the buffer itself

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, buf in self.named_buffers():
            >>>     if name in ['running_var']:
            >>>         print(buf.size())
        """
        def _buffer_items(module):
            # Buffers registered directly on `module`, as (name, tensor) pairs.
            return module._buffers.items()

        yield from self._named_members(
            _buffer_items,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

named_children

def named_children(
    self
) -> Iterator[Tuple[str, ForwardRef('Module')]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

Type Description
Tuple[str, Module] Tuple containing a name and child module
View Source
    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Iterate over the direct children of this module as ``(name, module)`` pairs.

        Only immediate submodules are visited (no recursion). ``None`` entries
        are skipped, and a child registered under several names is yielded
        only for the first of them.

        Yields:
            (str, Module): the child's name and the child module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)
        """
        seen = set()
        for child_name, child in self._modules.items():
            if child is None or child in seen:
                continue
            seen.add(child)
            yield child_name, child

named_modules

def named_modules(
    self,
    memo: Optional[Set[ForwardRef('Module')]] = None,
    prefix: str = '',
    remove_duplicate: bool = True
)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

Name Type Description Default
memo Optional[Set[Module]] a memo to store the set of modules already added to the result None
prefix str a prefix that will be added to the name of the module ''
remove_duplicate bool whether to remove the duplicated module instances in the result
or not
True

Yields:

Type Description
Tuple[str, Module] Tuple of name and module
View Source
    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Recursively iterate over this module and all its submodules as ``(name, module)`` pairs.

        The traversal is depth-first and pre-order: a module is yielded before
        its children. Child names are joined to the parent's name with ``'.'``.

        Args:
            memo: set of modules that were already yielded; a module found in
                it is skipped together with its whole subtree
            prefix: name given to this module; also the prefix for its children
            remove_duplicate: whether modules are recorded in ``memo`` so that a
                module instance reachable through several paths is yielded
                only once

        Yields:
            (str, Module): Tuple of name and module

        Note:
            With the default ``remove_duplicate=True``, a module registered
            several times is yielded once. In the following example, ``l``
            appears only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            ...     print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
        """
        visited = set() if memo is None else memo
        if self in visited:
            return
        if remove_duplicate:
            visited.add(self)
        yield prefix, self
        for child_name, child in self._modules.items():
            if child is not None:
                child_prefix = f"{prefix}.{child_name}" if prefix else child_name
                yield from child.named_modules(visited, child_prefix, remove_duplicate)

named_parameters

def named_parameters(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all parameter names. ''
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True
remove_duplicate bool whether to remove the duplicated
parameters in the result. Defaults to True.
True

Yields:

Type Description
Tuple[str, Parameter] Tuple containing the name and parameter
View Source
    def named_parameters(self, prefix: str = '', recurse: bool = True,
                         remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Iterate over module parameters as ``(name, parameter)`` pairs.

        Args:
            prefix (str): string prepended to every yielded parameter name.
            recurse (bool): when True, parameters of all submodules are
                included as well; when False, only parameters registered
                directly on this module are yielded.
            remove_duplicate (bool, optional): whether a parameter reachable
                through several paths is yielded only once. Defaults to True.

        Yields:
            (str, Parameter): the parameter's name and the parameter itself

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())
        """
        def _param_items(module):
            # Parameters registered directly on `module`, as (name, param) pairs.
            return module._parameters.items()

        yield from self._named_members(
            _param_items,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

parameters

def parameters(
    self,
    recurse: bool = True
) -> Iterator[torch.nn.parameter.Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True

Yields:

Type Description
Parameter module parameter
View Source
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Iterate over module parameters (without their names).

        The result is what is usually handed to an optimizer.

        Args:
            recurse (bool): when True, parameters of all submodules are
                included as well; when False, only parameters registered
                directly on this module are yielded.

        Yields:
            Parameter: module parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
        """
        named = self.named_parameters(recurse=recurse)
        for _param_name, param in named:
            yield param

register_backward_hook

def register_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of `torch.nn.Module.register_full_backward_hook`, and the behavior of this function will change in future versions.

Returns:

Type Description
torch.utils.hooks.RemovableHandle a handle that can be used to remove the added hook
by calling handle.remove()
View Source
    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Attach a legacy backward hook to this module.

        Deprecated: prefer :meth:`~torch.nn.Module.register_full_backward_hook`;
        the behavior of this function will change in future versions.

        Raises:
            RuntimeError: if full backward hooks are already in use on this
                module, since the two kinds cannot be mixed.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                call ``handle.remove()`` on it to detach the hook again
        """
        # Legacy and "full" backward hooks are mutually exclusive on one module.
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        # Record that this module now uses the legacy flavour of backward hooks.
        self._is_full_backward_hook = False

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

register_buffer

def register_buffer(
    self,
    name: str,
    tensor: Optional[torch.Tensor],
    persistent: bool = True
) -> None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting `persistent` to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's `state_dict`.

Buffers can be accessed as attributes using given names.

Parameters:

Name Type Description Default
name str name of the buffer. The buffer can be accessed
from this module using the given name
required
tensor Tensor or None buffer to be registered. If None, then operations
that run on buffers, such as `cuda`, are ignored. If None,
the buffer is not included in the module's `state_dict`.
required
persistent bool whether the buffer is part of this module's
`state_dict`.
True
View Source
    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Add a buffer to the module.

        A buffer holds module state that should not be considered a model
        parameter; BatchNorm's ``running_mean`` is a typical example. Buffers
        are persistent by default and are then saved alongside the parameters.
        Passing ``persistent=False`` keeps the buffer out of this module's
        :attr:`state_dict`; that is the only difference between the two kinds.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then
                operations that run on buffers, such as :attr:`cuda`, are
                ignored, and the buffer is **not** included in the module's
                :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Raises:
            RuntimeError: a non-persistent buffer is requested on a ScriptModule.
            AttributeError: ``Module.__init__()`` has not run yet.
            TypeError: ``name`` is not a string, or ``tensor`` is neither a
                Tensor nor ``None``.
            KeyError: ``name`` contains ``'.'``, is empty, or clashes with an
                existing attribute that is not already a buffer.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        # Validation runs in a fixed order; the first failing check raises.
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        if name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                            "(torch Tensor or None required)"
                            )

        # Global registration hooks may substitute the tensor that gets stored.
        for registration_hook in _global_buffer_registration_hooks.values():
            replacement = registration_hook(self, name, tensor)
            if replacement is not None:
                tensor = replacement

        self._buffers[name] = tensor
        # Persistence is tracked purely by membership in the non-persistent set.
        non_persistent = self._non_persistent_buffers_set
        update_membership = non_persistent.discard if persistent else non_persistent.add
        update_membership(name)

register_forward_hook

def register_forward_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False,
    always_call: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward hook on the module.

The hook will be called every time after :func:forward has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:forward is called. The hook should have the following signature::

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::

hook(module, args, kwargs, output) -> None or modified output

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If True, the provided hook will be fired
before all existing forward hooks on this
`torch.nn.modules.Module`. Otherwise, the provided
hook will be fired after all existing forward hooks on
this `torch.nn.modules.Module`. Note that global
forward hooks registered with
`register_module_forward_hook` will fire before all hooks
registered by this method.
Default: False
False
with_kwargs bool If True, the hook will be passed the
kwargs given to the forward function.
Default: False
False
always_call bool If True, the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle a handle that can be used to remove the added hook
by calling handle.remove()
View Source
    def register_forward_hook(

        self,

        hook: Union[

            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],

        ],

        *,

        prepend: bool = False,

        with_kwargs: bool = False,

        always_call: bool = False,

    ) -> RemovableHandle:

        r"""Register a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.

        If ``with_kwargs`` is ``False`` or not specified, the input contains only

        the positional arguments given to the module. Keyword arguments won't be

        passed to the hooks and only to the ``forward``. The hook can modify the

        output. It can modify the input inplace but it will not have effect on

        forward since this is called after :func:`forward` is called. The hook

        should have the following signature::

            hook(module, args, output) -> None or modified output

        If ``with_kwargs`` is ``True``, the forward hook will be passed the

        ``kwargs`` given to the forward function and be expected to return the

        output possibly modified. The hook should have the following signature::

            hook(module, args, kwargs, output) -> None or modified output

        Args:

            hook (Callable): The user defined hook to be registered.

            prepend (bool): If ``True``, the provided ``hook`` will be fired

                before all existing ``forward`` hooks on this

                :class:`torch.nn.modules.Module`. Otherwise, the provided

                ``hook`` will be fired after all existing ``forward`` hooks on

                this :class:`torch.nn.modules.Module`. Note that global

                ``forward`` hooks registered with

                :func:`register_module_forward_hook` will fire before all hooks

                registered by this method.

                Default: ``False``

            with_kwargs (bool): If ``True``, the ``hook`` will be passed the

                kwargs given to the forward function.

                Default: ``False``

            always_call (bool): If ``True`` the ``hook`` will be run regardless of

                whether an exception is raised while calling the Module.

                Default: ``False``

        Returns:

            :class:`torch.utils.hooks.RemovableHandle`:

                a handle that can be used to remove the added hook by calling

                ``handle.remove()``

        """

        handle = hooks.RemovableHandle(

            self._forward_hooks,

            extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],

        )

        self._forward_hooks[handle.id] = hook

        if with_kwargs:

            self._forward_hooks_with_kwargs[handle.id] = True

        if always_call:

            self._forward_hooks_always_called[handle.id] = True

        if prepend:

            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]

        return handle

register_forward_pre_hook

def register_forward_pre_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before :func:forward is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing forward_pre hooks on this
`torch.nn.modules.Module`. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this `torch.nn.modules.Module`. Note that global
forward_pre hooks registered with
`register_module_forward_pre_hook` will fire before all
hooks registered by this method.
Default: False
False
with_kwargs bool If true, the hook will be passed the kwargs
given to the forward function.
Default: False
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle a handle that can be used to remove the added hook
by calling handle.remove()
View Source
    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Attach a hook that runs every time before :func:`forward` is invoked.

        By default (``with_kwargs`` false) the hook only sees the positional
        arguments given to the module; keyword arguments go to ``forward``
        alone. The hook may replace the input by returning either a tuple or a
        single value; a single non-tuple value is wrapped into a tuple.
        Expected signature::

            hook(module, args) -> None or modified input

        With ``with_kwargs`` true the hook also receives the keyword arguments
        given to ``forward``. If it modifies the input, it must return both the
        args and the kwargs::

            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If true, ``hook`` fires before every
                ``forward_pre`` hook already registered on this
                :class:`torch.nn.modules.Module`; otherwise it fires after
                them. Global ``forward_pre`` hooks added via
                :func:`register_module_forward_pre_hook` still fire before all
                hooks registered by this method.
                Default: ``False``
            with_kwargs (bool): If true, ``hook`` is also passed the keyword
                arguments given to the forward function.
                Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                call ``handle.remove()`` on it to detach the hook again
        """
        pre_hooks = self._forward_pre_hooks
        kwargs_table = self._forward_pre_hooks_with_kwargs
        handle = hooks.RemovableHandle(pre_hooks, extra_dict=kwargs_table)
        hook_id = handle.id

        pre_hooks[hook_id] = hook
        if with_kwargs:
            # Flag this hook so it is called with (module, args, kwargs).
            kwargs_table[hook_id] = True
        if prepend:
            pre_hooks.move_to_end(hook_id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_hook

def register_full_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output arguments are tuples that contain the gradients with respect to the module's inputs and outputs, respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input only corresponds to the inputs given as positional arguments; keyword arguments are ignored. Entries in grad_input and grad_output are None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

Warning: Modifying inputs or outputs in-place is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this torch.nn.Module. Note that global
backward hooks registered with
register_module_full_backward_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle A handle that can be used to remove the added hook by calling
handle.remove().
View Source
    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a full backward hook to this module.

        The hook runs whenever gradients with respect to the module are
        computed, i.e. exactly when gradients with respect to the module's
        outputs are computed. Its expected signature is::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        ``grad_input`` and ``grad_output`` are tuples holding the gradients with
        respect to the module's inputs and outputs. The hook must not mutate
        them, but it may return a replacement for ``grad_input`` that is used in
        subsequent computations. ``grad_input`` only covers inputs passed
        positionally (keyword arguments are ignored), and entries belonging to
        non-Tensor arguments are ``None``.

        Because of how the hook is wired in, the module's forward receives a
        view of each input Tensor and the caller receives a view of each output
        Tensor.

        .. warning ::

            In-place modification of inputs or outputs is not allowed while
            backward hooks are in use and raises an error.

        Args:
            hook (Callable): The user-defined hook to register.
            prepend (bool): When ``True``, ``hook`` fires before every
                ``backward`` hook already registered on this module; otherwise
                it fires after them. Global ``backward`` hooks registered via
                :func:`register_module_full_backward_hook` always fire before
                hooks registered here.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.

        Raises:
            RuntimeError: if ``self._is_full_backward_hook`` is ``False``,
                i.e. regular (non-full) backward hooks are already in use.
        """
        # An explicit False marks that regular backward hooks were used; the
        # two kinds cannot be mixed on one module.
        if self._is_full_backward_hook is False:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        self._is_full_backward_hook = True

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        if prepend:
            # Newly inserted keys land at the end; move this one to the front
            # so it runs before the previously registered hooks.
            registry.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_pre_hook

def register_full_backward_pre_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature::

hook(module, grad_output) -> tuple[Tensor] or None

The :attr:grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:grad_output in subsequent computations. Entries in :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

Warning: Modifying inputs in-place is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward_pre hooks on this
torch.nn.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this torch.nn.Module. Note that global
backward_pre hooks registered with
register_module_full_backward_pre_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle A handle that can be used to remove the added hook by calling
handle.remove().
View Source
    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a backward pre-hook to this module.

        The hook runs every time gradients for the module are computed, with
        the signature::

            hook(module, grad_output) -> tuple[Tensor] or None

        ``grad_output`` is a tuple. The hook must not mutate it, but may return
        a replacement gradient with respect to the output that is used in place
        of ``grad_output`` in subsequent computations. Entries belonging to
        non-Tensor arguments are ``None``.

        Because of how the hook is wired in, the module's forward receives a
        view of each input Tensor and the caller receives a view of each output
        Tensor.

        .. warning ::

            In-place modification of inputs is not allowed while backward hooks
            are in use and raises an error.

        Args:
            hook (Callable): The user-defined hook to register.
            prepend (bool): When ``True``, ``hook`` fires before every
                ``backward_pre`` hook already registered on this module;
                otherwise it fires after them. Global ``backward_pre`` hooks
                registered via :func:`register_module_full_backward_pre_hook`
                always fire before hooks registered here.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._backward_pre_hooks
        handle = hooks.RemovableHandle(pre_hooks)
        pre_hooks[handle.id] = hook
        if prepend:
            # Insertion order is firing order; put this hook first.
            pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_load_state_dict_post_hook

def register_load_state_dict_post_hook(
    self,
    hook
)

Register a post hook to be run after module's load_state_dict is called.

It should have the following signature: hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling :func:load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

Type Description
torch.utils.hooks.RemovableHandle A handle that can be used to remove the added hook by calling
handle.remove().
View Source
    def register_load_state_dict_post_hook(self, hook):
        r"""Register a hook that runs after this module's ``load_state_dict``.

        Expected signature::

            hook(module, incompatible_keys) -> None

        ``module`` is the module the hook was registered on. ``incompatible_keys``
        is a ``NamedTuple`` with the attributes ``missing_keys`` and
        ``unexpected_keys``, each a ``list`` of ``str``. The hook may edit those
        lists in place if needed.

        With ``strict=True``, :func:`load_state_dict` checks the keys after the
        hook has run. Adding keys to either list therefore causes an error, and
        emptying both lists avoids one.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.
        """
        post_hooks = self._load_state_dict_post_hooks
        handle = hooks.RemovableHandle(post_hooks)
        post_hooks[handle.id] = hook
        return handle

register_module

def register_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Alias for :func:add_module.

View Source
    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child under ``name``.

        This is an alias: both arguments are forwarded unchanged to
        :func:`add_module`.
        """
        add_child = self.add_module
        add_child(name, module)

register_parameter

def register_parameter(
    self,
    name: str,
    param: Optional[torch.nn.parameter.Parameter]
) -> None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:

Name Type Description Default
name str Name of the parameter. The parameter can be accessed
from this module using the given name.
required
param Parameter or None Parameter to be added to the module. If
None, then operations that run on parameters, such as cuda,
are ignored. If None, the parameter is not included in the
module's state_dict.
required
View Source
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add a parameter named ``name`` to this module.

        Once registered, the parameter can be accessed as an attribute using
        the given name.

        Args:
            name (str): attribute name under which the parameter is stored.
                It must be a non-empty string without ``.`` and must not clash
                with an existing attribute that is not already a parameter.
            param (Parameter or None): the parameter to add. If ``None``,
                operations that run on parameters, such as :attr:`cuda`, skip
                it, and it is **not** included in the module's
                :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``Module.__init__()`` has run.
            TypeError: if ``name`` is not a ``str``, or ``param`` is neither a
                :class:`~torch.nn.Parameter` nor ``None``.
            KeyError: if ``name`` contains ``.``, is empty, or already names a
                non-parameter attribute.
            ValueError: if ``param`` has a ``grad_fn`` (non-leaf tensor).
        """
        # --- validate the name (checked in a fixed order) ---
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        if name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        # --- a None placeholder bypasses the remaining checks and hooks ---
        if param is None:
            self._parameters[name] = None
            return

        # --- validate the value ---
        if not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)"
                            )
        if param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")

        # Global registration hooks are chained: each one receives the
        # (possibly already substituted) parameter and may return a new one.
        for registration_hook in _global_parameter_registration_hooks.values():
            replacement = registration_hook(self, name, param)
            if replacement is not None:
                param = replacement
        self._parameters[name] = param

register_state_dict_pre_hook

def register_state_dict_pre_hook(
    self,
    hook
)

Register a pre-hook for the :meth:~torch.nn.Module.state_dict method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

View Source
    def register_state_dict_pre_hook(self, hook):
        r"""Register a hook that runs before :meth:`~torch.nn.Module.state_dict`.

        Each hook is called as ``hook(self, prefix, keep_vars)`` before
        ``state_dict`` runs on ``self``. Hooks can use this for pre-processing.

        Returns:
            a handle whose ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._state_dict_pre_hooks
        handle = hooks.RemovableHandle(pre_hooks)
        pre_hooks[handle.id] = hook
        return handle

requires_grad_

def requires_grad_(
    self: ~T,
    requires_grad: bool = True
) -> ~T

Change if autograd should record operations on parameters in this module.

This method sets the parameters' :attr:requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See :ref:locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:

Name Type Description Default
requires_grad bool Whether autograd should record operations on
parameters in this module.
True

Returns:

Type Description
Module self
View Source
    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Toggle autograd recording for every parameter of this module.

        Each parameter's :attr:`requires_grad` flag is updated in place. This
        is useful for freezing part of a model during fine-tuning, or for
        training sub-networks separately (e.g. GAN training).

        See :ref:`locally-disable-grad-doc` for how `.requires_grad_()` compares
        with similar mechanisms that are easy to confuse with it.

        Args:
            requires_grad (bool): whether autograd should record operations on
                this module's parameters. Default: ``True``.

        Returns:
            Module: self, so calls can be chained.
        """
        params = self.parameters()
        for param in params:
            param.requires_grad_(requires_grad)
        return self

set_extra_state

def set_extra_state(
    self,
    state: Any
) -> None

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.

View Source
    def set_extra_state(self, state: Any) -> None:
        """Restore extra state found in a loaded `state_dict`.

        :func:`load_state_dict` calls this for any extra state in the
        `state_dict`. Modules that store extra state must override this method
        together with :func:`get_extra_state`.

        This base implementation always raises: per its error message, the path
        should never be reached.

        Args:
            state (dict): Extra state from the `state_dict`

        Raises:
            RuntimeError: unconditionally.
        """
        message = (
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

share_memory

def share_memory(
    self: ~T
) -> ~T

See :meth:torch.Tensor.share_memory_.

View Source
    def share_memory(self: T) -> T:
        r"""Move every parameter and buffer to shared memory; see :meth:`torch.Tensor.share_memory_`.

        Returns:
            Module: the result of ``self._apply``.
        """
        def _to_shared(tensor):
            return tensor.share_memory_()

        return self._apply(_to_shared)

state_dict

def state_dict(
    self,
    *args,
    destination=None,
    prefix='',
    keep_vars=False
)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note: The returned object is a shallow copy. It contains references to the module's parameters and buffers.

Warning: Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars, in that order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning: Please avoid using the destination argument, as it is not designed for end-users.

Parameters:

Name Type Description Default
destination dict If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
None
prefix str A prefix added to parameter and buffer
names to compose the keys in state_dict.
''
keep_vars bool By default the tensors
returned in the state dict are detached from autograd. If it's
set to True, detaching will not be performed.
False

Returns:

Type Description
dict a dictionary containing a whole state of the module
View Source
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        r"""Collect references to the module's full state in a dictionary.

        The result holds the parameters and persistent buffers (e.g. running
        averages), keyed by their names. Parameters and buffers set to ``None``
        are omitted.

        .. note::

            The dictionary is a shallow copy: its values reference the module's
            own parameters and buffers.

        .. warning::

            ``destination``, ``prefix`` and ``keep_vars`` may still be passed
            positionally, in that order, but this is deprecated. Future
            releases will require keyword arguments.

        .. warning::

            ``destination`` is not intended for end-users.

        Args:
            destination (dict, optional): if given, the state is written into
                this mapping and the same object is returned. Otherwise a new
                ``OrderedDict`` is created. Default: ``None``.
            prefix (str, optional): string prepended to every parameter and
                buffer name to form the keys. Default: ``''``.
            keep_vars (bool, optional): if ``False``, the returned
                :class:`~torch.Tensor` s are detached from autograd. If
                ``True``, they are not detached. Default: ``False``.

        Returns:
            dict:
                a dictionary containing the whole state of the module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        if args:
            # Legacy positional order is (destination, prefix, keep_vars). A
            # positional value only fills a slot that still holds its default.
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == '':
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        # Per-module metadata is keyed by the prefix without its trailing '.'.
        local_metadata = {"version": self._version}
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        for pre_hook in self._state_dict_pre_hooks.values():
            pre_hook(self, prefix, keep_vars)

        self._save_to_state_dict(destination, prefix, keep_vars)

        # Recurse into children, extending the dotted key prefix.
        for child_name, child in self._modules.items():
            if child is None:
                continue
            child.state_dict(destination=destination, prefix=prefix + child_name + '.', keep_vars=keep_vars)

        # Post-hooks may replace the destination by returning a non-None value.
        for post_hook in self._state_dict_hooks.values():
            replacement = post_hook(self, destination, prefix, local_metadata)
            if replacement is not None:
                destination = replacement

        return destination

to

def to(
    self,
    *args,
    **kwargs
)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)

to(dtype, non_blocking=False)

to(tensor, non_blocking=False)

to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to, but it only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See the examples in the source below.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
device torch.device The desired device of the parameters
and buffers in this module.
None
dtype torch.dtype The desired floating point or complex dtype of
the parameters and buffers in this module.
None
tensor torch.Tensor Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module.
None
memory_format torch.memory_format The desired memory
format for 4D parameters and buffers in this module (keyword-only
argument).
None

Returns:

Type Description
Module self
View Source
    def to(self, *args, **kwargs):
        r"""Move and/or cast this module's parameters and buffers in place.

        Accepted call forms, mirroring :meth:`torch.Tensor.to`::

            to(device=None, dtype=None, non_blocking=False)
            to(dtype, non_blocking=False)
            to(tensor, non_blocking=False)
            to(memory_format=torch.channels_last)

        Only floating point or complex ``dtype`` values are accepted, and only
        floating point or complex parameters and buffers are cast to that dtype.
        Integral ones keep their dtype but are still moved to ``device`` when
        one is given. With ``non_blocking`` set, the conversion is attempted
        asynchronously with respect to the host where possible (e.g. pinned CPU
        memory to CUDA). ``memory_format`` is only applied to 4D and 5D tensors.

        Args:
            device (:class:`torch.device`): target device for the parameters
                and buffers in this module.
            dtype (:class:`torch.dtype`): target floating point or complex
                dtype for the parameters and buffers in this module.
            tensor (torch.Tensor): tensor whose dtype and device are used as the
                targets for all parameters and buffers in this module.
            memory_format (:class:`torch.memory_format`): target memory format
                for 4D/5D parameters and buffers (keyword-only argument).

        Returns:
            Module: self

        Raises:
            TypeError: if ``dtype`` is neither floating point nor complex.
            NotImplementedError: if copying out of a meta tensor fails. The
                error is re-raised with a hint to use :meth:`to_empty`.

        Examples::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> linear = nn.Linear(2, 2)
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight.dtype
            torch.float64
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point and not dtype.is_complex:
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                f'dtypes, but got desired dtype={dtype}')
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected.")

        def convert(t):
            try:
                # The memory format only applies to 4D/5D tensors.
                wants_format = convert_to_format is not None and t.dim() in (4, 5)
                # Integral tensors keep their dtype; only float/complex are cast.
                cast_to = dtype if t.is_floating_point() or t.is_complex() else None
                if wants_format:
                    return t.to(device, cast_to, non_blocking, memory_format=convert_to_format)
                return t.to(device, cast_to, non_blocking)
            except NotImplementedError as e:
                if str(e) != "Cannot copy out of meta tensor; no data!":
                    raise
                raise NotImplementedError(
                    f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                    f"when moving module from meta to a different device."
                ) from None

        return self._apply(convert)

to_empty

def to_empty(
    self: ~T,
    *,
    device: Union[int, str, torch.device, NoneType],
    recurse: bool = True
) -> ~T

Move the parameters and buffers to the specified device without copying storage.

Parameters:

Name Type Description Default
device torch.device The desired device of the parameters
and buffers in this module.
required
recurse bool Whether parameters and buffers of submodules should
be recursively moved to the specified device.
True

Returns:

Type Description
Module self
View Source
    def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:

        r"""Move the parameters and buffers to the specified device without copying storage.

        Args:

            device (:class:`torch.device`): The desired device of the parameters

                and buffers in this module.

            recurse (bool): Whether parameters and buffers of submodules should

                be recursively moved to the specified device.

        Returns:

            Module: self

        """

        return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)

train

def train(
    self: ~T,
    mode: bool = True
) -> ~T

Set the module in training mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Parameters:

Name Type Description Default
mode bool whether to set training mode (True) or evaluation
mode (False). Default: True.
None

Returns:

Type Description
Module self
View Source
    def train(self: T, mode: bool = True) -> T:
        r"""Switch this module and, recursively, all of its children to the given mode.

        Only modules whose behaviour depends on the mode (e.g. :class:`Dropout`,
        :class:`BatchNorm`) are affected; see their documentation for details.

        Args:
            mode (bool): ``True`` selects training mode, ``False`` selects
                evaluation mode. Default: ``True``.

        Returns:
            Module: self

        Raises:
            ValueError: if ``mode`` is not a ``bool``.
        """
        if isinstance(mode, bool):
            self.training = mode
            for child in self.children():
                child.train(mode)
            return self
        raise ValueError("training mode is expected to be boolean")

type

def type(
    self: ~T,
    dst_type: Union[torch.dtype, str]
) -> ~T

Casts all parameters and buffers to :attr:dst_type.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
dst_type type or string the desired type None

Returns:

Type Description
Module self
View Source
    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Cast every parameter and buffer to :attr:`dst_type`.

        .. note::
            The module is modified in-place.

        Args:
            dst_type (type or string): target type, forwarded to ``Tensor.type``.

        Returns:
            Module: self
        """
        def _cast(tensor):
            return tensor.type(dst_type)

        return self._apply(_cast)

xpu

def xpu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Transfer every parameter and buffer to the XPU.

        The transferred parameters and buffers become new objects, so call this
        before building an optimizer for a module that will be trained on XPU.

        .. note::
            The module is modified in-place.

        Args:
            device (int, optional): when given, the XPU device index to copy to.

        Returns:
            Module: self
        """
        def _to_xpu(tensor):
            return tensor.xpu(device)

        return self._apply(_to_xpu)

zero_grad

def zero_grad(
    self,
    set_to_none: bool = True
) -> None

Reset gradients of all model parameters.

See similar function under :class:torch.optim.Optimizer for more context.

Parameters:

Name Type Description Default
set_to_none bool instead of setting to zero, set the grads to None.
See :meth:torch.optim.Optimizer.zero_grad for details.
None
View Source
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Clear the gradients of every parameter of this module.

        Mirrors :meth:`torch.optim.Optimizer.zero_grad`.

        Args:
            set_to_none (bool): when ``True`` the gradients are dropped (set to
                ``None``); otherwise they are detached and filled with zeros.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
        """
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")

        for param in self.parameters():
            grad = param.grad
            if grad is None:
                continue
            if set_to_none:
                param.grad = None
                continue
            # A gradient that is itself part of a graph must be detached before zeroing.
            if grad.grad_fn is None:
                grad.requires_grad_(False)
            else:
                grad.detach_()
            grad.zero_()

MlpBlock

class MlpBlock(
    in_dim: int,
    dims: Sequence[int],
    nonlins: Sequence[Union[str, torch.nn.modules.module.Module]],
    batch_norm: bool = True
)

A general-purpose MLP.

Attributes

Name Type Description Default
in_dim int Input dimension. required
dims Sequence[int] Hidden dimensions, including output dimension. required
nonlins Sequence[Union[str, nn.Module]] Non-linearities to apply after each one of the hidden
dimensions.
Can be either a sequence of strings which are keys in the ACTIVATIONS
dict, or instances of nn.Module (e.g. an instance of nn.ReLU()).
Length should match 'dims'.
required
batch_norm bool Passed to every MLPLayer; controls its batch-norm layer. True
View Source
class MlpBlock(nn.Module):
    """
    A general-purpose MLP: a stack of ``MLPLayer`` blocks applied in sequence.

    Inputs are flattened to shape ``(N, -1)`` before the first layer.

    Args:
        in_dim: Input dimension (features per sample after flattening).
        dims: Hidden dimensions, including output dimension. Must be non-empty.
        nonlins: Non-linearities to apply after each one of the hidden
            dimensions.
            Can be either a sequence of strings which are keys in the ACTIVATIONS
            dict, or instances of nn.Module (e.g. an instance of nn.ReLU()).
            Length should match 'dims'.
        batch_norm: Passed unchanged to every ``MLPLayer`` (controls its batch-norm layer).
    """

    def __init__(
        self,
        in_dim: int,
        dims: Sequence[int],
        nonlins: Sequence[Union[str, nn.Module]],
        batch_norm: bool = True,
    ):
        # Initialise nn.Module's internal registries before assigning any attribute,
        # as required of every nn.Module subclass.
        super().__init__()

        assert len(nonlins) == len(dims), (
            f"expected one non-linearity per layer, got {len(nonlins)} "
            f"non-linearities for {len(dims)} dims"
        )
        assert len(dims) > 0, "dims must contain at least the output dimension"

        self.in_dim = in_dim
        self.out_dim = dims[-1]
        self.dims = dims
        self.nonlins = nonlins

        layers = []
        layer_in = in_dim
        for i, layer_out in enumerate(dims):
            layers.append(MLPLayer(layer_in, layer_out, nonlins[i], batch_norm))
            # Each layer's output width is the next layer's input width.
            layer_in = layer_out
        self.sequence = nn.Sequential(*layers)

    def _make_activation(self, act: Union[str, nn.Module]) -> nn.Module:
        """Instantiate ``act`` from ACTIVATIONS when it is a name; return module instances as-is."""
        if isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor, of shape (N, D) containing N samples with D features.
                Any trailing dimensions are flattened into D.

        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Invoke the Sequential via __call__ (not .forward) so registered hooks run.
        return self.sequence(x.reshape(x.size(0), -1))

Ancestors (in MRO)

  • torch.nn.modules.module.Module

Class variables

T_destination
call_super_init
dump_patches

Methods

add_module

def add_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters:

Name Type Description Default
name str name of the child module. The child module can be
accessed from this module using the given name
None
module Module child module to be added to the module. None
View Source
    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child of this module under ``name``.

        Once registered, the child is reachable as the attribute ``self.<name>``.
        Global module-registration hooks run first and may substitute a
        different module object.

        Args:
            name (str): attribute name for the child; must be a non-empty string
                without ``"."`` that does not clash with an existing attribute
                that is not already a registered submodule.
            module (Module or None): the child module to register (``None`` is allowed).

        Raises:
            TypeError: if ``module`` is neither a ``Module`` nor ``None``, or if
                ``name`` is not a ``str``.
            KeyError: if ``name`` clashes with an existing non-module attribute,
                contains ``"."``, or is the empty string.
        """
        # The checks run in this order, so the first failing one determines the error raised.
        if not isinstance(module, Module) and module is not None:
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        elif not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        elif '.' in name:
            raise KeyError(f"module name can't contain \".\", got: {name}")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        # Hooks are chained: each sees the (possibly replaced) module, and a
        # non-None return value replaces it for the remaining hooks and the registry.
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

apply

def apply(
    self: ~T,
    fn: Callable[[ForwardRef('Module')], NoneType]
) -> ~T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also :ref:nn-init-doc).

Parameters:

Name Type Description Default
fn Callable[[Module], None] Function to be applied to each submodule. None

Returns:

Type Description
Module self
View Source
    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Call ``fn`` on every submodule (as yielded by ``.children()``, recursively) and then on self.

        Children are visited depth-first, each before its parent, so ``fn``
        reaches this module last. A common use is parameter initialisation
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): callable invoked once per module.

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
        """
        for child in self.children():
            child.apply(fn)
        # Visit self only after all descendants have been processed.
        fn(self)
        return self

bfloat16

def bfloat16(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def bfloat16(self: T) -> T:
        r"""Convert floating-point parameters and buffers to ``bfloat16``; others are left alone.

        .. note::
            The module is modified in-place.

        Returns:
            Module: self
        """
        def _maybe_bf16(tensor):
            if not tensor.is_floating_point():
                return tensor
            return tensor.bfloat16()

        return self._apply(_maybe_bf16)

buffers

def buffers(
    self,
    recurse: bool = True
) -> Iterator[torch.Tensor]

Return an iterator over module buffers.

Parameters:

Name Type Description Default
recurse bool if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
None

Yields:

Type Description
torch.Tensor module buffer
View Source
    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Iterate over the buffers of this module.

        Args:
            recurse (bool): when ``True``, buffers of all submodules are included
                as well; otherwise only buffers registered directly on this
                module are yielded.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for buf in model.buffers():
            >>>     print(type(buf), buf.size())
        """
        yield from (tensor for _name, tensor in self.named_buffers(recurse=recurse))

children

def children(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over immediate children modules.

Yields:

Type Description
Module a child module
View Source
    def children(self) -> Iterator['Module']:
        r"""Iterate over the direct (non-recursive) child modules.

        Yields:
            Module: a child module
        """
        yield from (child for _label, child in self.named_children())

compile

def compile(
    self,
    *args,
    **kwargs
)

Compile this Module's forward using :func:torch.compile.

This Module's __call__ method is compiled and all arguments are passed as-is to :func:torch.compile.

See :func:torch.compile for details on the arguments for this function.

View Source
    def compile(self, *args, **kwargs):
        """
        Compile this Module's call path with :func:`torch.compile`.

        The module's ``__call__`` implementation is wrapped by :func:`torch.compile`,
        with every positional and keyword argument forwarded unchanged; see
        :func:`torch.compile` for their meaning.
        """
        compiled_impl = torch.compile(self._call_impl, *args, **kwargs)
        self._compiled_call_impl = compiled_impl

cpu

def cpu(
    self: ~T
) -> ~T

Move all model parameters and buffers to the CPU.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def cpu(self: T) -> T:
        r"""Transfer every parameter and buffer to the CPU.

        .. note::
            The module is modified in-place.

        Returns:
            Module: self
        """
        def _to_cpu(tensor):
            return tensor.cpu()

        return self._apply(_to_cpu)

cuda

def cuda(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Transfer every parameter and buffer to the GPU.

        The transferred parameters and buffers become new objects, so call this
        before building an optimizer for a module that will be trained on GPU.

        .. note::
            The module is modified in-place.

        Args:
            device (int, optional): when given, the GPU device index to copy to.

        Returns:
            Module: self
        """
        def _to_cuda(tensor):
            return tensor.cuda(device)

        return self._apply(_to_cuda)

double

def double(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to double datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def double(self: T) -> T:
        r"""Convert floating-point parameters and buffers to ``double``; others are left alone.

        .. note::
            The module is modified in-place.

        Returns:
            Module: self
        """
        def _maybe_double(tensor):
            if not tensor.is_floating_point():
                return tensor
            return tensor.double()

        return self._apply(_maybe_double)

eval

def eval(
    self: ~T
) -> ~T

Set the module in evaluation mode.

This only has an effect on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent with :meth:self.train(False) <torch.nn.Module.train>.

See :ref:locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Type Description
Module self
View Source
    def eval(self: T) -> T:
        r"""Put this module (and its children) into evaluation mode.

        Only modules whose behaviour depends on the mode (e.g. :class:`Dropout`,
        :class:`BatchNorm`) are affected; see their documentation for details.

        Equivalent to :meth:`self.train(False) <torch.nn.Module.train>`. See
        :ref:`locally-disable-grad-doc` for how `.eval()` differs from similar
        mechanisms such as disabling gradient computation.

        Returns:
            Module: self
        """
        training_mode = False
        return self.train(training_mode)

extra_repr

def extra_repr(
    self
) -> str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source
    def extra_repr(self) -> str:
        r"""Return extra text to embed in this module's ``repr``.

        The base implementation contributes nothing. Subclasses override it
        to show custom details; single-line and multi-line strings both work.
        """
        extra = ""
        return extra

float

def float(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to float datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def float(self: T) -> T:
        r"""Convert floating-point parameters and buffers to ``float``; others are left alone.

        .. note::
            The module is modified in-place.

        Returns:
            Module: self
        """
        def _maybe_float(tensor):
            if not tensor.is_floating_point():
                return tensor
            return tensor.float()

        return self._apply(_maybe_float)

forward

def forward(
    self,
    x: torch.Tensor
) -> torch.Tensor

Parameters:

Name Type Description Default
x None An input tensor, of shape (N, D) containing N samples with D features. None

Returns:

Type Description
None An output tensor of shape (N, D_out) where D_out is the output dim.
View Source
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor, of shape (N, D) containing N samples with D features.
                Any trailing dimensions are flattened into D.

        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        # Invoke the Sequential via __call__ (not .forward) so registered hooks run.
        return self.sequence(x.reshape(x.size(0), -1))

get_buffer

def get_buffer(
    self,
    target: str
) -> 'Tensor'

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the buffer
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.Tensor The buffer referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not a
buffer
View Source
    def get_buffer(self, target: str) -> "Tensor":
        """Look up the buffer named by the dotted path ``target``, raising if it does not exist.

        Everything before the last ``"."`` is resolved with ``get_submodule``
        (see its docstring for the path format); the final component must name
        a registered buffer on that submodule.

        Args:
            target: The fully-qualified string name of the buffer to look for.

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the path is invalid or the final attribute is
                not a registered buffer.
        """
        owner_path, _, attr = target.rpartition(".")
        owner: torch.nn.Module = self.get_submodule(owner_path)

        if not hasattr(owner, attr):
            raise AttributeError(owner._get_name() + " has no attribute `" + attr + "`")

        value: torch.Tensor = getattr(owner, attr)
        if attr in owner._buffers:
            return value
        raise AttributeError("`" + attr + "` is not a buffer")

get_extra_state

def get_extra_state(
    self
) -> Any

Return any extra state to include in the module's state_dict.

Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Type Description
object Any extra state to store in the module's state_dict
View Source
    def get_extra_state(self) -> Any:
        """Return extra, non-tensor state to store in the module's state_dict.

        Modules that need extra state implement this together with a matching
        :func:`set_extra_state`; it is called while building `state_dict()`.

        The returned object should be picklable so that the state_dict can be
        serialized. Backwards compatibility is only guaranteed for serialized
        Tensors; other objects may break if their pickled form changes.

        Returns:
            object: Any extra state to store in the module's state_dict

        Raises:
            RuntimeError: always, in this base implementation (it is not
                expected to be reached).
        """
        message = (
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

get_parameter

def get_parameter(
    self,
    target: str
) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the Parameter
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Parameter The Parameter referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Parameter
View Source
    def get_parameter(self, target: str) -> "Parameter":
        """Look up the parameter named by the dotted path ``target``, raising if it does not exist.

        Everything before the last ``"."`` is resolved with ``get_submodule``
        (see its docstring for the path format); the final component must be
        an ``nn.Parameter`` on that submodule.

        Args:
            target: The fully-qualified string name of the Parameter to look for.

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the path is invalid or the final attribute is
                not an ``nn.Parameter``.
        """
        owner_path, _, attr = target.rpartition(".")
        owner: torch.nn.Module = self.get_submodule(owner_path)

        if not hasattr(owner, attr):
            raise AttributeError(owner._get_name() + " has no attribute `" + attr + "`")

        candidate = getattr(owner, attr)
        if isinstance(candidate, torch.nn.Parameter):
            return candidate
        raise AttributeError("`" + attr + "` is not an nn.Parameter")

get_submodule

def get_submodule(
    self,
    target: str
) -> 'Module'

Return the submodule given by target if it exists, otherwise throw an error.

For example, let's say you have an nn.Module A that looks like this:

.. code-block:: text

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Module The submodule referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Module
View Source
    def get_submodule(self, target: str) -> "Module":
        """Resolve the dotted path ``target`` to a submodule, raising if it does not exist.

        Each ``"."``-separated component names an attribute of the previous
        module. For example, given::

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        ``get_submodule("net_b.linear")`` returns the ``Linear`` layer and
        ``get_submodule("net_b.net_c.conv")`` returns the ``Conv2d`` layer. An
        empty ``target`` returns this module itself.

        The cost is proportional to the nesting depth of ``target``, whereas
        scanning ``named_modules`` is O(N) in the total number of modules, so
        prefer this method for simple existence checks.

        Args:
            target: The fully-qualified string name of the submodule to look for.

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If a path component is missing or resolves to
                something that is not an ``nn.Module``.
        """
        if target == "":
            return self

        current: torch.nn.Module = self
        for component in target.split("."):
            if not hasattr(current, component):
                raise AttributeError(current._get_name() + " has no attribute `" + component + "`")
            current = getattr(current, component)
            if not isinstance(current, torch.nn.Module):
                raise AttributeError("`" + component + "` is not an nn.Module")
        return current

half

def half(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to half datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def half(self: T) -> T:
        r"""Convert floating-point parameters and buffers to ``half``; others are left alone.

        .. note::
            The module is modified in-place.

        Returns:
            Module: self
        """
        def _maybe_half(tensor):
            if not tensor.is_floating_point():
                return tensor
            return tensor.half()

        return self._apply(_maybe_half)

ipu

def ipu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Transfer every parameter and buffer to the IPU.

        The transferred parameters and buffers become new objects, so call this
        before building an optimizer for a module that will be trained on IPU.

        .. note::
            The module is modified in-place.

        Args:
            device (int, optional): when given, the IPU device index to copy to.

        Returns:
            Module: self
        """
        def _to_ipu(tensor):
            return tensor.ipu(device)

        return self._apply(_to_ipu)

load_state_dict

def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False
)

Copy parameters and buffers from `state_dict` into this module and its descendants.

If `strict` is `True`, then the keys of `state_dict` must exactly match the keys returned by this module's `torch.nn.Module.state_dict` function.

.. warning:: If `assign` is `True`, the optimizer must be created after the call to `load_state_dict` unless `torch.__future__.get_swap_module_params_on_conversion()` is `True`.

Parameters:

Name Type Description Default
state_dict dict a dict containing parameters and persistent buffers. required
strict bool whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `torch.nn.Module.state_dict` function. True
assign bool When `False`, the properties of the tensors in the current module are preserved, while when `True`, the properties of the Tensors in the state dict are preserved. The only exception is the `requires_grad` field of `torch.nn.Parameter`s, for which the value from the module is preserved. False

Returns:

Type Description
NamedTuple NamedTuple with `missing_keys` and `unexpected_keys` fields: `missing_keys` is a list of str containing the missing keys; `unexpected_keys` is a list of str containing the unexpected keys.
View Source
    def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True, assign: bool = False):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

        If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict` unless
            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): When ``False``, the properties of the tensors
                in the current module are preserved while when ``True``, the
                properties of the Tensors in the state dict are preserved. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
                for which the value from the module is preserved.
                Default: ``False``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            TypeError: if ``state_dict`` is not a ``Mapping``.
            RuntimeError: if loading produced errors, including missing or
                unexpected keys when ``strict`` is ``True``.
            AssertionError: if a post-hook registered with
                ``register_load_state_dict_post_hook`` returns a value.

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.

            The caller's ``state_dict`` (including its ``_metadata``) is not
            modified by this call, even when ``assign`` is ``True``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            # Per-module metadata is keyed by the module prefix without its trailing '.'.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                # ``metadata`` is the caller's own ``_metadata`` object (only the
                # outer mapping was copied above). Copy the per-module entry
                # before adding the flag, otherwise it would leak into later
                # loads of the same state dict, even ones with assign=False.
                local_metadata = dict(local_metadata)
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    # Hand each child only the entries under its own prefix.
                    child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821

            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

        load(self, state_dict)
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

modules

def modules(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over all modules in the network.

Yields:

Type Description
Module a module in the network
View Source
    def modules(self) -> Iterator['Module']:
        r"""Iterate over every module in the network, this module included.

        Yields:
            Module: a module in the network

        Note:
            A module reachable through several paths is yielded only once.
            In the example below ``l`` therefore appears a single time.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            ...     print(idx, '->', m)
            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)
        """
        # Reuse named_modules() and drop the qualified names.
        yield from (module for _qualified_name, module in self.named_modules())

named_buffers

def named_buffers(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all buffer names. ''
recurse bool if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True. True
remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True

Yields:

Type Description
None (str, torch.Tensor): Tuple containing the name and buffer
View Source
    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Iterate over module buffers, yielding ``(name, buffer)`` pairs.

        Args:
            prefix (str): string prepended to every buffer name.
            recurse (bool, optional): when True, include the buffers of all
                submodules as well; otherwise only buffers registered directly
                on this module are yielded. Defaults to True.
            remove_duplicate (bool, optional): whether buffers that appear more
                than once are reported only once. Defaults to True.

        Yields:
            (str, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, buf in self.named_buffers():
            >>>     if name in ['running_var']:
            >>>         print(buf.size())
        """

        def _own_buffers(module):
            # Only the buffers registered directly on ``module``.
            return module._buffers.items()

        yield from self._named_members(
            _own_buffers,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

named_children

def named_children(
    self
) -> Iterator[Tuple[str, ForwardRef('Module')]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

Type Description
None (str, Module): Tuple containing a name and child module
View Source
    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Iterate over the direct child modules, yielding ``(name, module)`` pairs.

        Children registered as ``None`` are skipped. A child module registered
        under several names is yielded only for its first name.

        Yields:
            (str, Module): Tuple containing a name and child module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)
        """
        already_yielded = set()
        for child_name, child in self._modules.items():
            if child is None or child in already_yielded:
                continue
            already_yielded.add(child)
            yield child_name, child

named_modules

def named_modules(
    self,
    memo: Optional[Set[ForwardRef('Module')]] = None,
    prefix: str = '',
    remove_duplicate: bool = True
)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

Name Type Description Default
memo Optional[Set[Module]] a memo to store the set of modules already added to the result None
prefix str a prefix that will be added to the name of the module ''
remove_duplicate bool whether to remove the duplicated module instances in the result or not True

Yields:

Type Description
None (str, Module): Tuple of name and module
View Source
    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Walk the module tree depth-first, yielding ``(qualified_name, module)`` pairs.

        This module is yielded first under ``prefix``. Its children follow
        recursively, with their names joined to ``prefix`` by ``'.'``.

        Args:
            memo: set of modules already yielded; it is shared across the
                recursive calls and updated only when ``remove_duplicate`` is true
            prefix: name under which this module is reported and which is
                prepended to every descendant's name
            remove_duplicate: whether a module instance reachable through
                several paths is reported only once

        Yields:
            (str, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            ...     print(idx, '->', m)
            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
        """
        seen = set() if memo is None else memo
        if self in seen:
            return
        if remove_duplicate:
            seen.add(self)
        yield prefix, self
        for child_name, child in self._modules.items():
            if child is not None:
                child_prefix = f"{prefix}.{child_name}" if prefix else child_name
                yield from child.named_modules(seen, child_prefix, remove_duplicate)

named_parameters

def named_parameters(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all parameter names. ''
recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. True
remove_duplicate bool whether to remove the duplicated parameters in the result. Defaults to True. True

Yields:

Type Description
None (str, Parameter): Tuple containing the name and parameter
View Source
    def named_parameters(
        self,
        prefix: str = '',
        recurse: bool = True,
        remove_duplicate: bool = True,
    ) -> Iterator[Tuple[str, Parameter]]:
        r"""Iterate over module parameters, yielding ``(name, parameter)`` pairs.

        Args:
            prefix (str): string prepended to every parameter name.
            recurse (bool): when True, include the parameters of all submodules
                as well; otherwise only parameters registered directly on this
                module are yielded.
            remove_duplicate (bool, optional): whether parameters that appear
                more than once are reported only once. Defaults to True.

        Yields:
            (str, Parameter): Tuple containing the name and parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())
        """
        member_options = dict(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
        yield from self._named_members(lambda mod: mod._parameters.items(), **member_options)

parameters

def parameters(
    self,
    recurse: bool = True
) -> Iterator[torch.nn.parameter.Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. True

Yields:

Type Description
Parameter module parameter
View Source
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Iterate over the module's parameters (typically handed to an optimizer).

        Args:
            recurse (bool): when True, include the parameters of all submodules
                as well; otherwise only parameters registered directly on this
                module are yielded.

        Yields:
            Parameter: module parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
        """
        named = self.named_parameters(recurse=recurse)
        for _param_name, param in named:
            yield param

register_backward_hook

def register_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of `torch.nn.Module.register_full_backward_hook`, and the behavior of this function will change in future versions.

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Register a (legacy) backward hook on the module.

        Deprecated: prefer :meth:`~torch.nn.Module.register_full_backward_hook`;
        the behavior of this function will change in future versions. A module
        cannot mix this kind of hook with full backward hooks.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``handle.remove()`` unregisters the hook

        Raises:
            RuntimeError: if a full backward hook is already registered on
                this module.
        """
        # ``_is_full_backward_hook`` records which hook style this module uses.
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        self._is_full_backward_hook = False

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

register_buffer

def register_buffer(
    self,
    name: str,
    tensor: Optional[torch.Tensor],
    persistent: bool = True
) -> None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's `running_mean` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting `persistent` to `False`. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's `state_dict`.

Buffers can be accessed as attributes using given names.

Parameters:

Name Type Description Default
name str name of the buffer. The buffer can be accessed from this module using the given name required
tensor Tensor or None buffer to be registered. If None, then operations that run on buffers, such as `cuda`, are ignored. If None, the buffer is not included in the module's `state_dict`. required
persistent bool whether the buffer is part of this module's `state_dict`. True
View Source
    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Add a buffer to the module.

        Use a buffer for state that should not be considered a model parameter.
        For example, BatchNorm's ``running_mean`` is part of the module's state
        but is not a parameter. Buffers are persistent by default and are saved
        alongside parameters. Pass ``persistent=False`` to keep a buffer out of
        this module's :attr:`state_dict`; that is the only difference between
        persistent and non-persistent buffers.

        Once registered, a buffer can be accessed as an attribute under ``name``.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Raises:
            RuntimeError: non-persistent buffer requested on a ScriptModule.
            AttributeError: called before ``Module.__init__()``.
            TypeError: ``name`` is not a str, or ``tensor`` is neither a
                Tensor nor None.
            KeyError: ``name`` contains ``'.'``, is empty, or clashes with an
                existing non-buffer attribute.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        # Guard clauses, checked in a fixed order: the first failure decides the error raised.
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        if name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                            "(torch Tensor or None required)"
                            )

        # Global registration hooks may substitute the tensor. Each one sees the
        # value left by the previous hooks, and a non-None return replaces it.
        for registration_hook in _global_buffer_registration_hooks.values():
            replacement = registration_hook(self, name, tensor)
            if replacement is not None:
                tensor = replacement

        self._buffers[name] = tensor
        non_persistent = self._non_persistent_buffers_set
        if persistent:
            non_persistent.discard(name)
        else:
            non_persistent.add(name)

register_forward_hook

def register_forward_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False,
    always_call: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward hook on the module.

The hook will be called every time after `forward` has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:forward is called. The hook should have the following signature::

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::

hook(module, args, kwargs, output) -> None or modified output

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If True, the provided hook will be fired before all existing forward hooks on this `torch.nn.modules.Module`. Otherwise, the provided hook will be fired after all existing forward hooks on this `torch.nn.modules.Module`. Note that global forward hooks registered with `register_module_forward_hook` will fire before all hooks registered by this method. False
with_kwargs bool If True, the hook will be passed the kwargs given to the forward function. False
always_call bool If True, the hook will be run regardless of whether an exception is raised while calling the Module. False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_hook(

        self,

        hook: Union[

            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],

        ],

        *,

        prepend: bool = False,

        with_kwargs: bool = False,

        always_call: bool = False,

    ) -> RemovableHandle:

        r"""Register a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.

        If ``with_kwargs`` is ``False`` or not specified, the input contains only

        the positional arguments given to the module. Keyword arguments won't be

        passed to the hooks and only to the ``forward``. The hook can modify the

        output. It can modify the input inplace but it will not have effect on

        forward since this is called after :func:`forward` is called. The hook

        should have the following signature::

            hook(module, args, output) -> None or modified output

        If ``with_kwargs`` is ``True``, the forward hook will be passed the

        ``kwargs`` given to the forward function and be expected to return the

        output possibly modified. The hook should have the following signature::

            hook(module, args, kwargs, output) -> None or modified output

        Args:

            hook (Callable): The user defined hook to be registered.

            prepend (bool): If ``True``, the provided ``hook`` will be fired

                before all existing ``forward`` hooks on this

                :class:`torch.nn.modules.Module`. Otherwise, the provided

                ``hook`` will be fired after all existing ``forward`` hooks on

                this :class:`torch.nn.modules.Module`. Note that global

                ``forward`` hooks registered with

                :func:`register_module_forward_hook` will fire before all hooks

                registered by this method.

                Default: ``False``

            with_kwargs (bool): If ``True``, the ``hook`` will be passed the

                kwargs given to the forward function.

                Default: ``False``

            always_call (bool): If ``True`` the ``hook`` will be run regardless of

                whether an exception is raised while calling the Module.

                Default: ``False``

        Returns:

            :class:`torch.utils.hooks.RemovableHandle`:

                a handle that can be used to remove the added hook by calling

                ``handle.remove()``

        """

        handle = hooks.RemovableHandle(

            self._forward_hooks,

            extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],

        )

        self._forward_hooks[handle.id] = hook

        if with_kwargs:

            self._forward_hooks_with_kwargs[handle.id] = True

        if always_call:

            self._forward_hooks_always_called[handle.id] = True

        if prepend:

            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]

        return handle

register_forward_pre_hook

def register_forward_pre_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before `forward` is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If true, the provided hook will be fired before all existing forward_pre hooks on this `torch.nn.modules.Module`. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this `torch.nn.modules.Module`. Note that global forward_pre hooks registered with `register_module_forward_pre_hook` will fire before all hooks registered by this method. False
with_kwargs bool If true, the hook will be passed the kwargs given to the forward function. False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Attach a hook that runs each time before :func:`forward` is invoked.

        By default (``with_kwargs`` false) the hook receives only the positional
        arguments passed to the module; keyword arguments go to ``forward``
        alone. The hook may replace the input by returning either a tuple or a
        single value; a single non-tuple value is wrapped into a tuple.
        Expected signature::

            hook(module, args) -> None or modified input

        With ``with_kwargs`` true the hook also receives the forward kwargs, and
        must return both args and kwargs if it modifies the input::

            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

        Args:
            hook (Callable): the user-defined hook to register.
            prepend (bool): when true, ``hook`` fires before every
                ``forward_pre`` hook already registered on this
                :class:`torch.nn.modules.Module`; otherwise it fires after them.
                Global ``forward_pre`` hooks registered with
                :func:`register_module_forward_pre_hook` still fire before all
                hooks registered by this method. Default: ``False``
            with_kwargs (bool): when true, ``hook`` is passed the kwargs given
                to the forward function. Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``handle.remove()`` unregisters the hook
        """
        registry = self._forward_pre_hooks
        kwargs_flags = self._forward_pre_hooks_with_kwargs
        handle = hooks.RemovableHandle(
            registry,
            extra_dict=kwargs_flags
        )
        hook_id = handle.id
        registry[hook_id] = hook
        if with_kwargs:
            kwargs_flags[hook_id] = True
        if prepend:
            # Move the new hook to the front of the registry for ``prepend``.
            registry.move_to_end(hook_id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_hook

def register_full_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The :attr:grad_input and :attr:grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:grad_input in subsequent computations. :attr:grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:grad_input and :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this :class:torch.nn.modules.Module. Note that global
backward hooks registered with
:func:register_module_full_backward_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a hook that runs when gradients w.r.t. this module are computed.

        The hook fires if and only if gradients with respect to the module's
        outputs are computed, and is invoked as::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        ``grad_input`` and ``grad_output`` are tuples holding the gradients with
        respect to the positional inputs and the outputs, respectively. Keyword
        arguments are ignored, and non-Tensor positions hold ``None``. The hook
        must not mutate its arguments. It may return a replacement for
        ``grad_input``, which is then used in subsequent computations.

        Because of how the hook is wired in, ``forward`` receives views of the
        Tensors passed to the module, and the caller receives views of the
        Tensors it returns.

        .. warning ::
            Modifying inputs or outputs inplace is not allowed when using
            backward hooks and will raise an error.

        Args:
            hook (Callable): the user-defined hook to register.
            prepend (bool): when ``True``, the hook runs before every
                ``backward`` hook already registered on this module; otherwise
                it runs after them. Global hooks from
                :func:`register_module_full_backward_hook` always run before
                hooks registered here. Default: ``False``.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: call ``handle.remove()``
            to unregister the hook.

        Raises:
            RuntimeError: if regular (non-full) backward hooks are already in
                use on this module.
        """
        # Regular and full backward hooks cannot be mixed on one module. A value
        # of ``False`` presumably means regular hooks were registered earlier
        # (per the error message below); ``None`` means neither kind is in use yet.
        if self._is_full_backward_hook is False:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        self._is_full_backward_hook = True

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        if prepend:
            # Move the new key to the front of the ordered mapping so it fires first.
            registry.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_pre_hook

def register_full_backward_pre_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature::

hook(module, grad_output) -> tuple[Tensor] or None

The :attr:grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:grad_output in subsequent computations. Entries in :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward_pre hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this :class:torch.nn.modules.Module. Note that global
backward_pre hooks registered with
:func:register_module_full_backward_pre_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a hook that runs whenever gradients for this module are computed.

        The hook is invoked with the signature::

            hook(module, grad_output) -> tuple[Tensor] or None

        ``grad_output`` is a tuple whose non-Tensor positions hold ``None``.
        The hook must not mutate its arguments. It may instead return a new
        gradient with respect to the output, which replaces ``grad_output`` in
        subsequent computations.

        Because of how the hook is wired in, ``forward`` receives views of the
        Tensors passed to the module, and the caller receives views of the
        Tensors it returns.

        .. warning ::
            Modifying inputs inplace is not allowed when using backward hooks
            and will raise an error.

        Args:
            hook (Callable): the user-defined hook to register.
            prepend (bool): when ``True``, the hook runs before every
                ``backward_pre`` hook already registered on this module;
                otherwise it runs after them. Global hooks from
                :func:`register_module_full_backward_pre_hook` always run before
                hooks registered here. Default: ``False``.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: call ``handle.remove()``
            to unregister the hook.
        """
        registry = self._backward_pre_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        if prepend:
            # Move the new key to the front of the ordered mapping so it fires first.
            registry.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_load_state_dict_post_hook

def register_load_state_dict_post_hook(
    self,
    hook
)

Register a post hook to be run after module's load_state_dict is called.

It should have the following signature:: hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling :func:load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_load_state_dict_post_hook(self, hook):
        r"""Register ``hook`` to run after :func:`load_state_dict` is called on this module.

        Expected signature::

            hook(module, incompatible_keys) -> None

        ``module`` is the module the hook was registered on. ``incompatible_keys``
        is a ``NamedTuple`` with attributes ``missing_keys`` and
        ``unexpected_keys``, each a ``list`` of ``str``. The hook may edit these
        lists in place.

        The ``strict=True`` checks of :func:`load_state_dict` see any edits the
        hook makes. Adding keys to either list results in an error, and clearing
        both lists avoids one.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: call ``handle.remove()``
            to unregister the hook.
        """
        registry = self._load_state_dict_post_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

register_module

def register_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Alias for :func:add_module.

View Source
    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child of this module under ``name``.

        Identical to :func:`add_module`, which performs all the work.
        """
        self.add_module(name, module)

register_parameter

def register_parameter(
    self,
    name: str,
    param: Optional[torch.nn.parameter.Parameter]
) -> None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:

Name Type Description Default
name str name of the parameter. The parameter can be accessed
from this module using the given name
required
param Parameter or None parameter to be added to the module. If
None, then operations that run on parameters, such as :attr:cuda,
are ignored. If None, the parameter is not included in the
module's :attr:state_dict.
required
View Source
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Register ``param`` on this module under ``name``.

        After registration, the parameter can be accessed as an attribute with
        the given name.

        Args:
            name (str): attribute name for the parameter. It must be a non-empty
                string without ``"."``, and must not clash with an existing
                attribute that is not already a parameter.
            param (Parameter or None): the parameter to add. If ``None``,
                operations that run on parameters (such as :attr:`cuda`) ignore
                it, and it is left out of the module's :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``Module.__init__()``.
            TypeError: if ``name`` is not a ``str``, or ``param`` is neither a
                ``torch.nn.Parameter`` nor ``None``.
            KeyError: if ``name`` contains ``"."``, is empty, or names an
                existing attribute that is not a parameter.
            ValueError: if ``param`` is a non-leaf Tensor (it has a ``grad_fn``).
        """
        # --- validate the name; checks are ordered so the first failure wins ---
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        if name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        # --- a None placeholder is stored as-is, without running global hooks ---
        if param is None:
            self._parameters[name] = None
            return

        # --- validate the value ---
        if not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)"
                            )
        if param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")

        # Each global registration hook may return a replacement parameter;
        # a ``None`` result keeps the current one.
        for registration_hook in _global_parameter_registration_hooks.values():
            replacement = registration_hook(self, name, param)
            if replacement is not None:
                param = replacement
        self._parameters[name] = param

register_state_dict_pre_hook

def register_state_dict_pre_hook(
    self,
    hook
)

Register a pre-hook for the :meth:~torch.nn.Module.state_dict method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

View Source
    def register_state_dict_pre_hook(self, hook):
        r"""Register ``hook`` to run before :meth:`~torch.nn.Module.state_dict` collects state.

        The hook receives ``self``, ``prefix`` and ``keep_vars`` before
        ``state_dict`` runs on ``self``. Use it for any pre-processing that must
        happen before the ``state_dict`` call.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: call ``handle.remove()``
            to unregister the hook.
        """
        registry = self._state_dict_pre_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

requires_grad_

def requires_grad_(
    self: ~T,
    requires_grad: bool = True
) -> ~T

Change if autograd should record operations on parameters in this module.

This method sets the parameters' :attr:requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See :ref:locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:

Name Type Description Default
requires_grad bool whether autograd should record operations on
parameters in this module. Default: True.
True

Returns:

Type Description
Module self
View Source
    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Set ``requires_grad`` in place on every parameter this module yields.

        This is handy for freezing part of a model during fine-tuning, or for
        training sub-models separately (e.g., GAN training). See
        :ref:`locally-disable-grad-doc` for how `.requires_grad_()` compares
        with similar mechanisms that are easily confused with it.

        Args:
            requires_grad (bool): whether autograd should record operations on
                this module's parameters. Default: ``True``.

        Returns:
            Module: self
        """
        for parameter in self.parameters():
            parameter.requires_grad_(requires_grad)
        return self

set_extra_state

def set_extra_state(
    self,
    state: Any
) -> None

Set extra state contained in the loaded state_dict.

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.

View Source
    def set_extra_state(self, state: Any) -> None:
        """Restore extra state taken from a loaded `state_dict`.

        :func:`load_state_dict` calls this to hand over any extra state it finds
        in the `state_dict`. Override it, together with a matching
        :func:`get_extra_state`, when your module needs to keep extra state in
        its `state_dict`.

        This base implementation only raises. Its message says the path should
        never be reached, so presumably it is only called when a subclass
        overrides it.

        Args:
            state (Any): the extra state read from the `state_dict`.

        Raises:
            RuntimeError: always, in this base implementation.
        """
        message = (
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

share_memory

def share_memory(
    self: ~T
) -> ~T

See :meth:torch.Tensor.share_memory_.

View Source
    def share_memory(self: T) -> T:
        r"""Apply :meth:`torch.Tensor.share_memory_` to every parameter and buffer.

        Returns:
            Whatever ``self._apply`` returns (the module itself for ``nn.Module``).
        """
        def _to_shared(tensor):
            return tensor.share_memory_()

        return self._apply(_to_shared)

state_dict

def state_dict(
    self,
    *args,
    destination=None,
    prefix='',
    keep_vars=False
)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

.. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers.

.. warning:: Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

.. warning:: Please avoid the use of argument destination as it is not designed for end-users.

Parameters:

Name Type Description Default
destination dict If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
None
prefix str a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
''
keep_vars bool by default the :class:~torch.Tensor s
returned in the state dict are detached from autograd. If it's
set to True, detaching will not be performed.
Default: False.
False

Returns:

Type Description
dict a dictionary containing a whole state of the module
View Source
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        r"""Return a mapping that references the module's whole state.

        Parameters and persistent buffers (e.g. running averages) are stored
        under their names. Parameters and buffers whose value is ``None`` are
        skipped.

        .. note::
            The result is a shallow copy: it holds references to the module's
            own parameter and buffer objects.

        .. warning::
            ``destination``, ``prefix`` and ``keep_vars`` are still accepted
            positionally, in that order. This is deprecated; future releases
            will require keyword arguments.

        .. warning::
            ``destination`` is not designed for end users; avoid passing it.

        Args:
            destination (dict, optional): mapping to update and return. When
                omitted, a new ``OrderedDict`` is created. Default: ``None``.
            prefix (str, optional): string prepended to every parameter and
                buffer name to form its key. Default: ``''``.
            keep_vars (bool, optional): when ``False``, the returned Tensors are
                detached from autograd; ``True`` skips detaching.
                Default: ``False``.

        Returns:
            dict: the module's whole state.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        if args:
            # Legacy positional order is (destination, prefix, keep_vars). A
            # positional value is used only while the keyword is at its default.
            if destination is None:
                destination = args[0]
            if prefix == '' and len(args) > 1:
                prefix = args[1]
            if keep_vars is False and len(args) > 2:
                keep_vars = args[2]
            # Emitted without a category (i.e. as a UserWarning) because
            # DeprecationWarning is ignored by default.
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        # Record this module's version under its prefix without the trailing ".".
        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        for pre_hook in self._state_dict_pre_hooks.values():
            pre_hook(self, prefix, keep_vars)

        self._save_to_state_dict(destination, prefix, keep_vars)

        for child_name, child in self._modules.items():
            if child is None:
                continue
            child.state_dict(destination=destination, prefix=prefix + child_name + '.', keep_vars=keep_vars)

        # A post-hook may return a replacement mapping.
        for post_hook in self._state_dict_hooks.values():
            replacement = post_hook(self, destination, prefix, local_metadata)
            if replacement is not None:
                destination = replacement
        return destination

to

def to(
    self,
    *args,
    **kwargs
)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)

to(dtype, non_blocking=False)

to(tensor, non_blocking=False)

to(memory_format=torch.channels_last)

Its signature is similar to :meth:torch.Tensor.to, but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:dtype (if given). The integral parameters and buffers will be moved to :attr:device, if that is given, but with dtypes unchanged. When :attr:non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device torch.device the desired device of the parameters
and buffers in this module
None
dtype torch.dtype the desired floating point or complex dtype of
the parameters and buffers in this module
None
tensor torch.Tensor Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
None
memory_format torch.memory_format the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
None

Returns:

Type Description
Module self
View Source
    def to(self, *args, **kwargs):
        r"""Move and/or cast this module's parameters and buffers, in place.

        Accepted call forms:

        .. function:: to(device=None, dtype=None, non_blocking=False)
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:

        Arguments are interpreted as in :meth:`torch.Tensor.to`, except that a
        requested :attr:`dtype` must be floating point or complex. Only the
        floating point and complex parameters and buffers are cast to that
        dtype. Integral ones keep their dtype and are only moved to
        :attr:`device` when a device is given. With :attr:`non_blocking`, the
        conversion is attempted asynchronously with respect to the host where
        possible (e.g. pinned CPU memory to CUDA). A requested memory format is
        applied only to 4D and 5D tensors.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): target device for the parameters and
                buffers in this module.
            dtype (:class:`torch.dtype`): target floating point or complex dtype.
            tensor (torch.Tensor): a Tensor whose dtype and device become the
                targets for every parameter and buffer in this module.
            memory_format (:class:`torch.memory_format`): target memory format
                for 4D/5D parameters and buffers (keyword only argument).

        Returns:
            Module: self

        Raises:
            TypeError: if ``dtype`` is neither floating point nor complex.
            NotImplementedError: if a meta tensor would have to be copied; use
                :meth:`to_empty` in that case.

        Examples::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> linear = nn.Linear(2, 2)
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight.dtype
            torch.float64
            >>> linear.to(torch.device("cpu"), dtype=torch.float32).weight.dtype
            torch.float32
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None and not (dtype.is_floating_point or dtype.is_complex):
            raise TypeError('nn.Module.to only accepts floating point or complex '
                            f'dtypes, but got desired dtype={dtype}')
        if dtype is not None and dtype.is_complex:
            warnings.warn(
                "Complex modules are a new feature under active development whose design may change, "
                "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                "if a complex module does not work as expected.")

        def convert(t):
            try:
                # Integral tensors get dtype=None so that only their device changes.
                target_dtype = dtype if (t.is_floating_point() or t.is_complex()) else None
                if convert_to_format is None or t.dim() not in (4, 5):
                    return t.to(device, target_dtype, non_blocking)
                return t.to(device, target_dtype, non_blocking, memory_format=convert_to_format)
            except NotImplementedError as e:
                if str(e) != "Cannot copy out of meta tensor; no data!":
                    raise
                # Replace the opaque meta-tensor error with a pointer to to_empty().
                raise NotImplementedError(
                    f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                    f"when moving module from meta to a different device."
                ) from None

        return self._apply(convert)

to_empty

def to_empty(
    self: ~T,
    *,
    device: Union[int, str, torch.device, NoneType],
    recurse: bool = True
) -> ~T

Move the parameters and buffers to the specified device without copying storage.

Parameters:

Name Type Description Default
device torch.device The desired device of the parameters
and buffers in this module.
required
recurse bool Whether parameters and buffers of submodules should
be recursively moved to the specified device.
True

Returns:

Type Description
Module self
View Source
    def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:

        r"""Move the parameters and buffers to the specified device without copying storage.

        Args:

            device (:class:`torch.device`): The desired device of the parameters

                and buffers in this module.

            recurse (bool): Whether parameters and buffers of submodules should

                be recursively moved to the specified device.

        Returns:

            Module: self

        """

        return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)

train

def train(
    self: ~T,
    mode: bool = True
) -> ~T

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. `Dropout`, `BatchNorm`, etc.

Parameters:

Name Type Description Default
mode bool Whether to set training mode (True) or evaluation
mode (False). Default: True.
True

Returns:

Type Description
Module self
View Source
    def train(self: T, mode: bool = True) -> T:
        r"""Switch the module into training (``True``) or evaluation (``False``) mode.

        Sets ``self.training`` to ``mode`` and then calls ``train(mode)`` on every
        immediate child. Only modules whose behavior depends on the mode (e.g.
        :class:`Dropout`, :class:`BatchNorm`) are affected in practice.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for evaluation mode.
                Default: ``True``.

        Returns:
            Module: self

        Raises:
            ValueError: If ``mode`` is not a ``bool``.
        """
        if isinstance(mode, bool):
            self.training = mode
            for child in self.children():
                child.train(mode)
            return self
        raise ValueError("training mode is expected to be boolean")

type

def type(
    self: ~T,
    dst_type: Union[torch.dtype, str]
) -> ~T

Casts all parameters and buffers to :attr:dst_type.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
dst_type type or string the desired type None

Returns:

Type Description
Module self
View Source
    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Cast every parameter and buffer with ``Tensor.type(dst_type)``.

        .. note::
            This method modifies the module in-place.

        Args:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """

        def _cast(tensor):
            return tensor.type(dst_type)

        return self._apply(_cast)

xpu

def xpu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move every parameter and buffer to an XPU device via ``Tensor.xpu(device)``.

        The moved parameters and buffers are new objects, so any optimizer should be
        constructed after this call if the module is to be optimized on the XPU.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if given, the XPU device everything is copied to.

        Returns:
            Module: self
        """

        def _to_xpu(tensor):
            return tensor.xpu(device)

        return self._apply(_to_xpu)

zero_grad

def zero_grad(
    self,
    set_to_none: bool = True
) -> None

Reset gradients of all model parameters.

See the similar method `torch.optim.Optimizer.zero_grad` for more context.

Parameters:

Name Type Description Default
set_to_none bool Instead of setting to zero, set the grads to None.
See `torch.optim.Optimizer.zero_grad` for details.
True
View Source
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Clear the gradients of all parameters of this module.

        With ``set_to_none=True`` each existing ``.grad`` is replaced by ``None``.
        Otherwise the gradient tensor is detached from any autograd graph (or has
        ``requires_grad`` switched off) and then filled with zeros in place.
        Parameters whose ``.grad`` is already ``None`` are left untouched.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
        """
        if getattr(self, '_is_replica', False):
            # DataParallel replicas hold non-leaf copies, so clearing them is pointless.
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")

        for param in self.parameters():
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None
                continue
            if param.grad.grad_fn is None:
                param.grad.requires_grad_(False)
            else:
                param.grad.detach_()
            param.grad.zero_()

RMLP

class RMLP(
    block_in_dim: int,
    block_dims: Sequence[int],
    block_nonlins: Sequence[Union[str, torch.nn.modules.module.Module]],
    n_blocks: int,
    out_dim: int,
    in_dim: int = None,
    batch_norm: bool = True
)

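Residual multi-layer perceptron: an optional input layer (created when `in_dim` is given), `n_blocks` MLP blocks whose outputs are added to their inputs, and a final linear output layer. RMLP has no docstring of its own, so the remaining text is inherited from `torch.nn.Module`, the base class for all neural network modules.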

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to`, etc.

.. note:: As per the example above, an __init__() call to the parent class must be made before assignment on the child.

View Source
class RMLP(nn.Module):
    """
    A residual multi-layer perceptron.

    The input is optionally projected to ``block_in_dim`` features. It then passes through
    ``n_blocks`` MLP blocks with residual connections (``x = x + block(x)``). Finally, a linear
    layer maps it to ``out_dim`` features.

    Args:
        block_in_dim: Number of features entering every residual block.
        block_dims: Layer dimensions inside each block. Whenever an input layer or at least
            one block is present, the last entry must equal ``block_in_dim``. This is because
            each block's output is added to its input, and the sum is fed to the output
            layer, which expects ``block_dims[-1]`` features.
        block_nonlins: Activations (names or modules) for the layers inside each block.
            The first entry is also used by the input layer when ``in_dim`` is given.
        n_blocks: Number of residual blocks.
        out_dim: Number of output features.
        in_dim: If not None, an ``MLPLayer(in_dim, block_in_dim, ...)`` input layer is
            created. Otherwise the input is passed through unchanged.
        batch_norm: Forwarded to the input layer and to every block.

    Raises:
        ValueError: If an input layer or at least one block is present and
            ``block_dims[-1] != block_in_dim``.
    """

    def __init__(
        self,
        block_in_dim: int,
        block_dims: Sequence[int],
        block_nonlins: Sequence[Union[str, nn.Module]],
        n_blocks: int,
        out_dim: int,
        in_dim: Union[int, None] = None,  # if in_dim is an int, then a first layer will be made
        batch_norm: bool = True,
    ) -> None:
        super().__init__()

        # When block_in_dim-wide features feed a residual sum or the output layer, their width
        # must match block_dims[-1]. Failing here gives a clear message instead of a shape
        # error on the first forward pass.
        if (n_blocks > 0 or in_dim is not None) and block_dims[-1] != block_in_dim:
            raise ValueError(
                f"block_dims[-1] ({block_dims[-1]}) must equal block_in_dim ({block_in_dim}) "
                "for the residual connections and output layer to be shape-compatible"
            )

        # Create first layer if in_dim is not None
        self.input = nn.Identity()
        if in_dim is not None:
            self.input = MLPLayer(in_dim, block_in_dim, block_nonlins[0], batch_norm)

        # Create blocks
        self.blocks = nn.ModuleList(
            [MlpBlock(block_in_dim, block_dims, block_nonlins, batch_norm) for _ in range(n_blocks)]
        )

        # Create output layer
        self.output = nn.Linear(block_dims[-1], out_dim)

    def _make_activation(self, act: Union[str, nn.Module]) -> nn.Module:
        """
        Build an activation module.

        A string is looked up in ``ACTIVATIONS`` and instantiated with the matching
        ``ACTIVATION_DEFAULT_KWARGS``. A module instance is returned unchanged.
        """
        if isinstance(act, str):
            return ACTIVATIONS[act](**ACTIVATION_DEFAULT_KWARGS[act])
        return act

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: An input tensor, of shape (N, D) containing N samples with D features.

        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        x = self.input(x)
        for block in self.blocks:
            # Residual connection: each block refines x instead of replacing it.
            x = x + block(x)
        return self.output(x)

Ancestors (in MRO)

  • torch.nn.modules.module.Module

Class variables

T_destination
call_super_init
dump_patches

Methods

add_module

def add_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters:

Name Type Description Default
name str name of the child module. The child module can be
accessed from this module using the given name
None
module Module child module to be added to the module. None
View Source
    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child of this module under ``name``.

        Once registered, the child is reachable as an attribute with that name. Every
        global module-registration hook runs before the child is stored. A hook that
        returns a non-``None`` value replaces the module being registered.

        Args:
            name (str): name of the child module. It must be a non-empty string without
                ``.`` that is not already used by a non-module attribute.
            module (Module): child module to be added to the module (or ``None``).

        Raises:
            TypeError: If ``module`` is neither a Module nor ``None``, or if ``name``
                is not a string.
            KeyError: If ``name`` clashes with an existing attribute, contains ``.``,
                or is empty.
        """
        if module is not None and not isinstance(module, Module):
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        if not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
        if hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        if '.' in name:
            raise KeyError(f"module name can't contain \".\", got: {name}")
        if name == '':
            raise KeyError("module name can't be empty string \"\"")

        for registration_hook in _global_module_registration_hooks.values():
            substitute = registration_hook(self, name, module)
            if substitute is not None:
                module = substitute

        self._modules[name] = module

apply

def apply(
    self: ~T,
    fn: Callable[[ForwardRef('Module')], NoneType]
) -> ~T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also the `torch.nn.init` documentation).

Parameters:

Name Type Description Default
fn Callable[[Module], None] Function to be applied to each submodule. None

Returns:

Type Description
Module self
View Source
    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Call ``fn`` on every submodule (recursively, via ``.children()``) and then on self.

        Each child's subtree is processed before ``fn`` is called on this module. A
        typical use is initializing the parameters of a model (see also
        :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
        """
        for submodule in self.children():
            submodule.apply(fn)
        # This module is visited last, after all of its descendants.
        fn(self)
        return self

bfloat16

def bfloat16(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def bfloat16(self: T) -> T:
        r"""Convert every floating-point parameter and buffer to ``bfloat16``.

        Non-floating-point tensors (e.g. integer buffers) are left as they are.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """

        def _to_bfloat16(tensor):
            if tensor.is_floating_point():
                return tensor.bfloat16()
            return tensor

        return self._apply(_to_bfloat16)

buffers

def buffers(
    self,
    recurse: bool = True
) -> Iterator[torch.Tensor]

Return an iterator over module buffers.

Parameters:

Name Type Description Default
recurse bool If True, yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
True

Yields:

Type Description
torch.Tensor module buffer
View Source
    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Iterate over the buffers of this module.

        The names are dropped from the results of ``named_buffers``.

        Args:
            recurse (bool): if True, buffers of all submodules are included as well.
                Otherwise only buffers that are direct members of this module are yielded.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for buf in model.buffers():
            >>>     print(type(buf), buf.size())
        """
        yield from (buffer for _name, buffer in self.named_buffers(recurse=recurse))

children

def children(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over immediate children modules.

Yields:

Type Description
Module a child module
View Source
    def children(self) -> Iterator['Module']:
        r"""Iterate over the immediate child modules, without their names.

        Yields:
            Module: a child module
        """
        yield from (child for _name, child in self.named_children())

compile

def compile(
    self,
    *args,
    **kwargs
)

Compile this Module's forward using :func:torch.compile.

This Module's __call__ method is compiled and all arguments are passed as-is to :func:torch.compile.

See :func:torch.compile for details on the arguments for this function.

View Source
    def compile(self, *args, **kwargs):
        """
        Compile this Module's forward using :func:`torch.compile`.

        The module's ``_call_impl`` is wrapped with ``torch.compile`` and the result is
        stored as ``_compiled_call_impl``. Every argument is forwarded unchanged to
        :func:`torch.compile`; see its documentation for their meaning.
        """
        compiled_call = torch.compile(self._call_impl, *args, **kwargs)
        self._compiled_call_impl = compiled_call

cpu

def cpu(
    self: ~T
) -> ~T

Move all model parameters and buffers to the CPU.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def cpu(self: T) -> T:
        r"""Move every parameter and buffer to the CPU via ``Tensor.cpu()``.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """

        def _to_cpu(tensor):
            return tensor.cpu()

        return self._apply(_to_cpu)

cuda

def cuda(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move every parameter and buffer to a GPU via ``Tensor.cuda(device)``.

        The moved parameters and buffers are new objects, so any optimizer should be
        constructed after this call if the module is to be optimized on the GPU.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if given, the GPU device everything is copied to.

        Returns:
            Module: self
        """

        def _to_cuda(tensor):
            return tensor.cuda(device)

        return self._apply(_to_cuda)

double

def double(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to double datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def double(self: T) -> T:
        r"""Convert every floating-point parameter and buffer to ``double``.

        Non-floating-point tensors (e.g. integer buffers) are left as they are.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """

        def _to_double(tensor):
            if tensor.is_floating_point():
                return tensor.double()
            return tensor

        return self._apply(_to_double)

eval

def eval(
    self: ~T
) -> ~T

Set the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. `Dropout`, `BatchNorm`, etc.

This is equivalent to calling `self.train(False)` (see `torch.nn.Module.train`).

See the PyTorch note "Locally disabling gradient computation" for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Type Description
Module self
View Source
    def eval(self: T) -> T:
        r"""Put the module into evaluation mode by delegating to ``self.train(False)``.

        Only modules whose behavior differs between training and evaluation (e.g.
        :class:`Dropout`, :class:`BatchNorm`) are affected; see their documentation.
        See :ref:`locally-disable-grad-doc` for how this differs from mechanisms that
        disable gradient computation.

        Returns:
            Module: self
        """
        evaluation_mode = False
        return self.train(evaluation_mode)

extra_repr

def extra_repr(
    self
) -> str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source
    def extra_repr(self) -> str:
        r"""Return extra information to show in this module's ``repr``.

        The base implementation contributes nothing (an empty string). Subclasses may
        override it; both single-line and multi-line strings are acceptable.
        """
        return ""

float

def float(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to float datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def float(self: T) -> T:
        r"""Convert every floating-point parameter and buffer to ``float``.

        Non-floating-point tensors (e.g. integer buffers) are left as they are.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """

        def _to_float(tensor):
            if tensor.is_floating_point():
                return tensor.float()
            return tensor

        return self._apply(_to_float)

forward

def forward(
    self,
    x: torch.Tensor
) -> torch.Tensor

Parameters:

Name Type Description Default
x None An input tensor, of shape (N, D) containing N samples with D features. None

Returns:

Type Description
None An output tensor of shape (N, D_out) where D_out is the output dim.
View Source
    def forward(self, x: Tensor) -> Tensor:
        """
        Run the input layer, then each residual block, then the output layer.

        Args:
            x: An input tensor, of shape (N, D) containing N samples with D features.

        Returns:
            An output tensor of shape (N, D_out) where D_out is the output dim.
        """
        hidden = self.input(x)
        for residual_block in self.blocks:
            # Each block's output is added to its own input (residual connection).
            hidden = hidden + residual_block(hidden)
        return self.output(hidden)

get_buffer

def get_buffer(
    self,
    target: str
) -> 'Tensor'

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the buffer
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.Tensor The buffer referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not a
buffer
View Source
    def get_buffer(self, target: str) -> "Tensor":
        """Return the buffer named by the dotted path ``target``.

        ``target`` is split at its last ``.`` into a submodule path and an attribute
        name. The submodule path is resolved with ``get_submodule``, and the attribute
        is read from that submodule. See ``get_submodule`` for the path syntax.

        Args:
            target: The fully-qualified string name of the buffer to look for.

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the path is invalid, the attribute does not exist, or
                the attribute is not registered as a buffer.
        """
        owner_path, _sep, attr_name = target.rpartition(".")
        owner: torch.nn.Module = self.get_submodule(owner_path)

        if not hasattr(owner, attr_name):
            raise AttributeError(owner._get_name() + " has no attribute `" + attr_name + "`")

        found: torch.Tensor = getattr(owner, attr_name)

        # Parameters and plain attributes are also reachable via getattr; only
        # entries registered in ``_buffers`` qualify.
        if attr_name not in owner._buffers:
            raise AttributeError("`" + attr_name + "` is not a buffer")

        return found

get_extra_state

def get_extra_state(
    self
) -> Any

Return any extra state to include in the module's state_dict.

Implement this and a corresponding :func:set_extra_state for your module if you need to store extra state. This function is called when building the module's state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Type Description
object Any extra state to store in the module's state_dict
View Source
    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Override this, together with :func:`set_extra_state`, when a module needs to
        store state beyond its parameters and buffers. It is called while building the
        module's `state_dict()`. Extra state should be picklable; backwards
        compatibility is only guaranteed for serialized Tensors. Other objects may
        break if their pickled form changes.

        The base implementation should never be reached and always raises.

        Returns:
            object: Any extra state to store in the module's state_dict

        Raises:
            RuntimeError: Always, in this base implementation.
        """
        message = (
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

get_parameter

def get_parameter(
    self,
    target: str
) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the Parameter
to look for. (See get_submodule for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Parameter The Parameter referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Parameter
View Source
    def get_parameter(self, target: str) -> "Parameter":
        """Return the parameter named by the dotted path ``target``.

        ``target`` is split at its last ``.`` into a submodule path and an attribute
        name. The submodule path is resolved with ``get_submodule``, and the attribute
        is read from that submodule. See ``get_submodule`` for the path syntax.

        Args:
            target: The fully-qualified string name of the Parameter to look for.

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the path is invalid, the attribute does not exist, or
                the attribute is not an ``nn.Parameter``.
        """
        owner_path, _sep, attr_name = target.rpartition(".")
        owner: torch.nn.Module = self.get_submodule(owner_path)

        if not hasattr(owner, attr_name):
            raise AttributeError(owner._get_name() + " has no attribute `" + attr_name + "`")

        candidate: torch.nn.Parameter = getattr(owner, attr_name)
        if isinstance(candidate, torch.nn.Parameter):
            return candidate
        raise AttributeError("`" + attr_name + "` is not an nn.Parameter")

get_submodule

def get_submodule(
    self,
    target: str
) -> 'Module'

Return the submodule given by target if it exists, otherwise throw an error.

For example, let's say you have an nn.Module A that looks like this:

.. code-block:: text

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

Name Type Description Default
target None The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
None

Returns:

Type Description
torch.nn.Module The submodule referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Module
View Source
    def get_submodule(self, target: str) -> "Module":
        """Return the submodule named by the dotted path ``target``.

        Each ``.``-separated component is looked up as an attribute on the module
        found so far. For example, given::

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        ``get_submodule("net_b.linear")`` returns the ``Linear`` layer, and
        ``get_submodule("net_b.net_c.conv")`` returns the ``Conv2d`` layer. An empty
        ``target`` returns ``self``.

        The lookup costs time proportional to the nesting depth of ``target``. A search
        through ``named_modules`` would instead be linear in the total number of modules.

        Args:
            target: The fully-qualified string name of the submodule to look for.

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If a path component does not exist, or resolves to
                something that is not an ``nn.Module``.
        """
        if target == "":
            return self

        current: torch.nn.Module = self
        for attr_name in target.split("."):
            if not hasattr(current, attr_name):
                raise AttributeError(current._get_name() + " has no attribute `" + attr_name + "`")
            current = getattr(current, attr_name)
            if not isinstance(current, torch.nn.Module):
                raise AttributeError("`" + attr_name + "` is not an nn.Module")
        return current

half

def half(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to half datatype.

.. note:: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def half(self: T) -> T:
        r"""Convert every floating-point parameter and buffer to ``half``.

        Non-floating-point tensors (e.g. integer buffers) are left as they are.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """

        def _to_half(tensor):
            if tensor.is_floating_point():
                return tensor.half()
            return tensor

        return self._apply(_to_half)

ipu

def ipu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move every parameter and buffer to an IPU device via ``Tensor.ipu(device)``.

        The moved parameters and buffers are new objects, so any optimizer should be
        constructed after this call if the module is to be optimized on the IPU.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if given, the IPU device everything is copied to.

        Returns:
            Module: self
        """

        def _to_ipu(tensor):
            return tensor.ipu(device)

        return self._apply(_to_ipu)

load_state_dict

def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False
)

Copy parameters and buffers from :attr:state_dict into this module and its descendants.

If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.

.. warning:: If :attr:assign is True the optimizer must be created after the call to :attr:load_state_dict unless :func:~torch.__future__.get_swap_module_params_on_conversion is True.

Parameters:

Name Type Description Default
state_dict dict a dict containing parameters and
persistent buffers.
None
strict bool whether to strictly enforce that the keys
in :attr:state_dict match the keys returned by this module's
:meth:~torch.nn.Module.state_dict function. Default: True
True
assign bool When False, the properties of the tensors
in the current module are preserved while when True, the
properties of the Tensors in the state dict are preserved. The only
exception is the requires_grad field of :class:~torch.nn.Parameter objects
for which the value from the module is preserved.
Default: False
False

Returns:

Type Description
None NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
View Source
    def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True, assign: bool = False):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

        If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict` unless
            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): When ``False``, the properties of the tensors
                in the current module are preserved while when ``True``, the
                properties of the Tensors in the state dict are preserved. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
                for which the value from the module is preserved.
                Default: ``False``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            TypeError: if :attr:`state_dict` is not a ``Mapping``.
            RuntimeError: if any error messages were collected while loading
                (including missing/unexpected keys when :attr:`strict` is ``True``).

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            # ``metadata`` is the caller's own ``_metadata`` object. Copy the
            # per-module entry before adding the ``assign`` flag: the flag is only
            # ever set (never cleared), so writing it into the shared dict would
            # make a later ``assign=False`` load of the same state_dict assign too.
            local_metadata = {} if metadata is None else dict(metadata.get(prefix[:-1], {}))
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    # Each child only sees the entries under its own prefix.
                    child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821

            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

        load(self, state_dict)
        # ``load`` refers to itself recursively; deleting the name breaks that cycle.
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

modules

def modules(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over all modules in the network.

Yields:

Type Description
Module a module in the network
View Source
    def modules(self) -> Iterator['Module']:
        r"""Iterate over every module in the network, starting with this one.

        This is :meth:`named_modules` with the names dropped.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            ...     print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)
        """
        yield from (submodule for _, submodule in self.named_modules())

named_buffers

def named_buffers(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all buffer names. ''
recurse bool if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
True
remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True

Yields:

Type Description
None (str, torch.Tensor): Tuple containing the name and buffer
View Source
    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Iterate over module buffers, yielding each buffer together with its name.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool, optional): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module. Defaults to True.
            remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

        Yields:
            (str, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, buf in self.named_buffers():
            >>>     if name in ['running_var']:
            >>>         print(buf.size())
        """
        def _own_buffers(module):
            # Each visited module contributes the entries of its own ``_buffers`` dict.
            return module._buffers.items()

        yield from self._named_members(
            _own_buffers,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

named_children

def named_children(
    self
) -> Iterator[Tuple[str, ForwardRef('Module')]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

Type Description
None (str, Module): Tuple containing a name and child module
View Source
    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Iterate over the immediate child modules, yielding each with its name.

        ``None`` entries are skipped, and a child registered under several
        names is yielded only for the first of them.

        Yields:
            (str, Module): Tuple containing a name and child module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, module in model.named_children():
            >>>     if name in ['conv4', 'conv5']:
            >>>         print(module)
        """
        seen = set()
        for child_name, child in self._modules.items():
            if child is None or child in seen:
                continue
            seen.add(child)
            yield child_name, child

named_modules

def named_modules(
    self,
    memo: Optional[Set[ForwardRef('Module')]] = None,
    prefix: str = '',
    remove_duplicate: bool = True
)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

Name Type Description Default
memo None a memo to store the set of modules already added to the result None
prefix str a prefix that will be added to the name of the module ''
remove_duplicate bool whether to remove the duplicated module instances in the result
or not
True

Yields:

Type Description
None (str, Module): Tuple of name and module
View Source
    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Iterate over all modules in the network, yielding each with its dotted name.

        The traversal is depth-first, and this module comes first.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (str, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            ...     print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
        """
        visited = set() if memo is None else memo
        if self in visited:
            return
        # Only record this module when de-duplication is requested; otherwise a
        # shared submodule is yielded once per path that reaches it.
        if remove_duplicate:
            visited.add(self)
        yield prefix, self
        for child_name, child in self._modules.items():
            if child is None:
                continue
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            yield from child.named_modules(visited, child_prefix, remove_duplicate)

named_parameters

def named_parameters(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all parameter names. ''
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True
remove_duplicate bool whether to remove the duplicated
parameters in the result. Defaults to True.
True

Yields:

Type Description
None (str, Parameter): Tuple containing the name and parameter
View Source
    def named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
        r"""Iterate over module parameters, yielding each parameter together with its name.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                parameters in the result. Defaults to True.

        Yields:
            (str, Parameter): Tuple containing the name and parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, param in self.named_parameters():
            >>>     if name in ['bias']:
            >>>         print(param.size())
        """
        def _own_parameters(module):
            # Each visited module contributes the entries of its own ``_parameters`` dict.
            return module._parameters.items()

        yield from self._named_members(
            _own_parameters,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

parameters

def parameters(
    self,
    recurse: bool = True
) -> Iterator[torch.nn.parameter.Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True

Yields:

Type Description
Parameter module parameter
View Source
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Iterate over module parameters (without their names).

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for param in model.parameters():
            >>>     print(type(param), param.size())
            <class 'torch.Tensor'> (20,)
            <class 'torch.Tensor'> (20, 1, 5, 5)
        """
        yield from (param for _, param in self.named_parameters(recurse=recurse))

register_backward_hook

def register_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of :meth:~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions.

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Register a (legacy) backward hook on the module.

        This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions.

        Raises:
            RuntimeError: if a full backward hook was already registered on this module.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        # Regular and full backward hooks are mutually exclusive on one module.
        if self._is_full_backward_hook is True:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        self._is_full_backward_hook = False

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

register_buffer

def register_buffer(
    self,
    name: str,
    tensor: Optional[torch.Tensor],
    persistent: bool = True
) -> None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:state_dict.

Buffers can be accessed as attributes using given names.

Parameters:

Name Type Description Default
name str name of the buffer. The buffer can be accessed
from this module using the given name
None
tensor Tensor or None buffer to be registered. If None, then operations
that run on buffers, such as :attr:cuda, are ignored. If None,
the buffer is not included in the module's :attr:state_dict.
None
persistent bool whether the buffer is part of this module's
:attr:state_dict.
True
View Source
    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Add a buffer to the module.

        Use this for state that should not be considered a model parameter. For
        example, BatchNorm's ``running_mean`` is not a parameter, but is part of
        the module's state. Buffers are persistent by default and are saved
        alongside parameters. Pass ``persistent=False`` to change that: the only
        difference between a persistent and a non-persistent buffer is that the
        latter is not part of this module's :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Raises:
            RuntimeError: non-persistent buffer requested on a ``ScriptModule``.
            AttributeError: called before ``Module.__init__()``.
            TypeError: ``name`` is not a string, or ``tensor`` is neither a Tensor nor ``None``.
            KeyError: ``name`` contains ``"."``, is empty, or clashes with an existing non-buffer attribute.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> self.register_buffer('running_mean', torch.zeros(num_features))
        """
        # Validation: each failed check raises, in the same order as before.
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        if name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                            "(torch Tensor or None required)"
                            )

        # Global registration hooks may substitute the tensor by returning a new one.
        for registration_hook in _global_buffer_registration_hooks.values():
            replacement = registration_hook(self, name, tensor)
            if replacement is not None:
                tensor = replacement

        self._buffers[name] = tensor
        # Non-persistent names are tracked in a set that state_dict consults.
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)

register_forward_hook

def register_forward_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False,
    always_call: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward hook on the module.

The hook will be called every time after :func:forward has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:forward is called. The hook should have the following signature::

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::

hook(module, args, kwargs, output) -> None or modified output

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. None
prepend bool If True, the provided hook will be fired
before all existing forward hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this :class:torch.nn.modules.Module. Note that global
forward hooks registered with
:func:register_module_forward_hook will fire before all hooks
registered by this method.
Default: False
False
with_kwargs bool If True, the hook will be passed the
kwargs given to the forward function.
Default: False
False
always_call bool If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_hook(

        self,

        hook: Union[

            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],

        ],

        *,

        prepend: bool = False,

        with_kwargs: bool = False,

        always_call: bool = False,

    ) -> RemovableHandle:

        r"""Register a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.

        If ``with_kwargs`` is ``False`` or not specified, the input contains only

        the positional arguments given to the module. Keyword arguments won't be

        passed to the hooks and only to the ``forward``. The hook can modify the

        output. It can modify the input inplace but it will not have effect on

        forward since this is called after :func:`forward` is called. The hook

        should have the following signature::

            hook(module, args, output) -> None or modified output

        If ``with_kwargs`` is ``True``, the forward hook will be passed the

        ``kwargs`` given to the forward function and be expected to return the

        output possibly modified. The hook should have the following signature::

            hook(module, args, kwargs, output) -> None or modified output

        Args:

            hook (Callable): The user defined hook to be registered.

            prepend (bool): If ``True``, the provided ``hook`` will be fired

                before all existing ``forward`` hooks on this

                :class:`torch.nn.modules.Module`. Otherwise, the provided

                ``hook`` will be fired after all existing ``forward`` hooks on

                this :class:`torch.nn.modules.Module`. Note that global

                ``forward`` hooks registered with

                :func:`register_module_forward_hook` will fire before all hooks

                registered by this method.

                Default: ``False``

            with_kwargs (bool): If ``True``, the ``hook`` will be passed the

                kwargs given to the forward function.

                Default: ``False``

            always_call (bool): If ``True`` the ``hook`` will be run regardless of

                whether an exception is raised while calling the Module.

                Default: ``False``

        Returns:

            :class:`torch.utils.hooks.RemovableHandle`:

                a handle that can be used to remove the added hook by calling

                ``handle.remove()``

        """

        handle = hooks.RemovableHandle(

            self._forward_hooks,

            extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],

        )

        self._forward_hooks[handle.id] = hook

        if with_kwargs:

            self._forward_hooks_with_kwargs[handle.id] = True

        if always_call:

            self._forward_hooks_always_called[handle.id] = True

        if prepend:

            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]

        return handle

register_forward_pre_hook

def register_forward_pre_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before :func:forward is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. None
prepend bool If true, the provided hook will be fired before
all existing forward_pre hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this :class:torch.nn.modules.Module. Note that global
forward_pre hooks registered with
:func:register_module_forward_pre_hook will fire before all
hooks registered by this method.
Default: False
False
with_kwargs bool If true, the hook will be passed the kwargs
given to the forward function.
Default: False
False

Returns:

Type Description
None :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Register a hook that runs every time before :func:`forward` is invoked.

        When ``with_kwargs`` is false (the default), the hook only receives the
        positional arguments given to the module; keyword arguments go to
        ``forward`` alone. The hook may modify the input by returning either a
        tuple or a single value. A single value is wrapped into a tuple unless it
        already is one. Signature::

            hook(module, args) -> None or modified input

        When ``with_kwargs`` is true, the hook is also passed the kwargs given to
        the forward function. If it modifies the input, it must return both args
        and kwargs. Signature::

            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If true, the hook fires before all existing
                ``forward_pre`` hooks on this :class:`torch.nn.modules.Module`;
                otherwise it fires after them. Global ``forward_pre`` hooks
                registered with :func:`register_module_forward_pre_hook` still
                fire before all hooks registered by this method.
                Default: ``False``
            with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
                given to the forward function. Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        registry = self._forward_pre_hooks
        kwargs_flags = self._forward_pre_hooks_with_kwargs
        # Removing the handle also drops its entry from the kwargs flag table.
        handle = hooks.RemovableHandle(registry, extra_dict=kwargs_flags)
        registry[handle.id] = hook

        if with_kwargs:
            kwargs_flags[handle.id] = True
        if prepend:
            registry.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_hook

def register_full_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The :attr:grad_input and :attr:grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:grad_input in subsequent computations. :attr:grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:grad_input and :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. None
prepend bool If true, the provided hook will be fired before
all existing backward hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this :class:torch.nn.modules.Module. Note that global
backward hooks registered with
:func:register_module_full_backward_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a full backward hook to this module.

        ``hook`` runs every time gradients with respect to the module's outputs
        are computed, and is invoked as::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        ``grad_input`` and ``grad_output`` are tuples of gradients for the
        positional inputs and for the outputs; entries are ``None`` for
        non-Tensor arguments and keyword arguments are ignored. The hook must
        not modify its arguments, but it may return a replacement for
        ``grad_input`` that is used in subsequent computations. While such a
        hook is installed, ``forward`` receives views of its input Tensors and
        callers receive views of the Tensors it returns.

        .. warning ::
            Modifying inputs or outputs inplace is not allowed when using
            backward hooks and will raise an error.

        Args:
            hook (Callable): The user-defined hook to register.
            prepend (bool): When ``True``, the hook is placed ahead of every
                ``backward`` hook already registered on this module instead of
                after them. Global hooks registered through
                :func:`register_module_full_backward_hook` still fire before
                all hooks registered here. Default: ``False``.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: a handle whose
            ``remove()`` method unregisters the hook.

        Raises:
            RuntimeError: if ``self._is_full_backward_hook`` is ``False``,
                i.e. (per the error message) the module already uses regular
                backward hooks.
        """
        # Identity check on purpose: only an explicit ``False`` blocks this,
        # because regular and full backward hooks cannot be mixed.
        if self._is_full_backward_hook is False:
            raise RuntimeError(
                "Cannot use both regular backward hooks and full backward hooks on a "
                "single Module. Please use only one of them."
            )
        self._is_full_backward_hook = True

        registry = self._backward_hooks
        new_handle = hooks.RemovableHandle(registry)
        key = new_handle.id
        registry[key] = hook
        if prepend:
            # Move the new entry to the front of the ordered registry.
            registry.move_to_end(key, last=False)  # type: ignore[attr-defined]
        return new_handle

register_full_backward_pre_hook

def register_full_backward_pre_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature::

hook(module, grad_output) -> tuple[Tensor] or None

The :attr:grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:grad_output in subsequent computations. Entries in :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. None
prepend bool If true, the provided hook will be fired before
all existing backward_pre hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this :class:torch.nn.modules.Module. Note that global
backward_pre hooks registered with
:func:register_module_full_backward_pre_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
torch.utils.hooks.RemovableHandle
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a backward pre-hook to this module.

        ``hook`` runs every time the gradients for the module are computed and
        is invoked as::

            hook(module, grad_output) -> tuple[Tensor] or None

        ``grad_output`` is a tuple whose entries are ``None`` for non-Tensor
        arguments. The hook must not modify its arguments, but it may return a
        replacement for ``grad_output`` that is used in subsequent
        computations. While such a hook is installed, ``forward`` receives
        views of its input Tensors and callers receive views of its outputs.

        .. warning ::
            Modifying inputs inplace is not allowed when using backward hooks
            and will raise an error.

        Args:
            hook (Callable): The user-defined hook to register.
            prepend (bool): When ``True``, the hook is placed ahead of every
                ``backward_pre`` hook already registered on this module instead
                of after them. Global hooks registered through
                :func:`register_module_full_backward_pre_hook` still fire
                before all hooks registered here. Default: ``False``.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: a handle whose
            ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._backward_pre_hooks
        new_handle = hooks.RemovableHandle(pre_hooks)
        hook_id = new_handle.id
        pre_hooks[hook_id] = hook
        if prepend:
            # Move the new entry to the front of the ordered registry.
            pre_hooks.move_to_end(hook_id, last=False)  # type: ignore[attr-defined]
        return new_handle

register_load_state_dict_post_hook

def register_load_state_dict_post_hook(
    self,
    hook
)

Register a post hook to be run after module's load_state_dict is called.

It should have the following signature:: hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling :func:load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

Type Description
torch.utils.hooks.RemovableHandle
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_load_state_dict_post_hook(self, hook):
        r"""Register ``hook`` to run after this module's ``load_state_dict``.

        The hook is called as::

            hook(module, incompatible_keys) -> None

        ``module`` is the module the hook is registered on. ``incompatible_keys``
        is a ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys``, each
        a ``list`` of ``str``, and it may be modified in place. Such edits affect
        the ``strict=True`` check of :func:`load_state_dict`: adding keys to
        either list makes it raise, while clearing both lists avoids an error.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: a handle whose
            ``remove()`` method unregisters the hook.
        """
        post_hooks = self._load_state_dict_post_hooks
        new_handle = hooks.RemovableHandle(post_hooks)
        post_hooks[new_handle.id] = hook
        return new_handle

register_module

def register_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Alias for :func:add_module.

View Source
    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child called ``name``.

        This is an alias for :func:`add_module`; all validation is done there.
        """
        add_child = self.add_module
        add_child(name, module)

register_parameter

def register_parameter(
    self,
    name: str,
    param: Optional[torch.nn.parameter.Parameter]
) -> None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:

Name Type Description Default
name str name of the parameter. The parameter can be accessed
from this module using the given name
None
param Parameter or None parameter to be added to the module. If
None, then operations that run on parameters, such as :attr:cuda,
are ignored. If None, the parameter is not included in the
module's :attr:state_dict.
None
View Source
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add ``param`` to this module under ``name``.

        Once registered, the parameter is reachable as an attribute called
        ``name``.

        Args:
            name (str): attribute name for the parameter. It must be a non-empty
                string without ``.`` that is not already used by a
                non-parameter attribute.
            param (Parameter or None): the parameter to add. ``None`` is stored
                as-is: operations that run on parameters (such as
                :attr:`cuda`) skip it and it is **not** included in the
                module's :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``Module.__init__()``.
            TypeError: if ``name`` is not a string, or ``param`` is neither a
                ``Parameter`` nor ``None``.
            KeyError: if ``name`` contains ``.``, is empty, or clashes with an
                existing non-parameter attribute.
            ValueError: if ``param`` has a ``grad_fn`` (a non-leaf Tensor).
        """
        # Name validation: the order of these guards decides which error a
        # caller sees when several problems apply at once.
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        if name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        # An explicit None placeholder is stored without further checks.
        if param is None:
            self._parameters[name] = None
            return

        if not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)"
                            )
        if param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")

        # Global registration hooks may substitute a different parameter object.
        for registration_hook in _global_parameter_registration_hooks.values():
            replacement = registration_hook(self, name, param)
            if replacement is not None:
                param = replacement
        self._parameters[name] = param

register_state_dict_pre_hook

def register_state_dict_pre_hook(
    self,
    hook
)

Register a pre-hook for the :meth:~torch.nn.Module.state_dict method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

View Source
    def register_state_dict_pre_hook(self, hook):
        r"""Register a hook to run before :meth:`~torch.nn.Module.state_dict`.

        ``hook`` is called with ``self``, ``prefix`` and ``keep_vars`` before
        ``state_dict`` collects anything from ``self``. This allows
        pre-processing ahead of the ``state_dict`` call.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`: a handle whose
            ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._state_dict_pre_hooks
        new_handle = hooks.RemovableHandle(pre_hooks)
        pre_hooks[new_handle.id] = hook
        return new_handle

requires_grad_

def requires_grad_(
    self: ~T,
    requires_grad: bool = True
) -> ~T

Change if autograd should record operations on parameters in this module.

This method sets the parameters' :attr:requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See :ref:locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:

Name Type Description Default
requires_grad bool whether autograd should record operations on
parameters in this module. Default: True.
True

Returns:

Type Description
Module self
View Source
    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Turn autograd recording on or off for every parameter of this module.

        Each parameter's :attr:`requires_grad` flag is updated in place. This is
        handy for freezing part of a model during finetuning, or for training
        parts of a model separately (e.g. GAN training). See
        :ref:`locally-disable-grad-doc` for how ``.requires_grad_()`` compares
        with similar mechanisms.

        Args:
            requires_grad (bool): whether autograd should record operations on
                this module's parameters. Default: ``True``.

        Returns:
            Module: self, to allow chaining.
        """
        flag = requires_grad
        for tensor in self.parameters():
            tensor.requires_grad_(flag)
        return self

set_extra_state

def set_extra_state(
    self,
    state: Any
) -> None

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state for your module if you need to store extra state within its state_dict.

View Source
    def set_extra_state(self, state: Any) -> None:
        """Restore extra state taken from a loaded `state_dict`.

        :func:`load_state_dict` calls this for any extra state it finds. Modules
        that keep extra state in their `state_dict` override this together with
        :func:`get_extra_state`. The base implementation is never meant to be
        reached, so it always raises.

        Args:
            state: the extra state read from the `state_dict`.

        Raises:
            RuntimeError: always, in this base implementation.
        """
        message = (
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

share_memory

def share_memory(
    self: ~T
) -> ~T

See :meth:torch.Tensor.share_memory_.

View Source
    def share_memory(self: T) -> T:
        r"""Apply :meth:`torch.Tensor.share_memory_` to this module's tensors.

        Returns:
            Module: self (the result of ``_apply``).
        """
        def _into_shared_memory(tensor):
            return tensor.share_memory_()

        return self._apply(_into_shared_memory)

state_dict

def state_dict(
    self,
    *args,
    destination=None,
    prefix='',
    keep_vars=False
)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

.. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers.

.. warning:: Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

.. warning:: Please avoid the use of argument destination as it is not designed for end-users.

Parameters:

Name Type Description Default
destination dict If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
None
prefix str a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
''
keep_vars bool by default the :class:~torch.Tensor s
returned in the state dict are detached from autograd. If it's
set to True, detaching will not be performed.
Default: False.
False

Returns:

Type Description
dict a dictionary containing a whole state of the module
View Source
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        r"""Collect references to the whole state of this module into a dict.

        Parameters and persistent buffers (e.g. running averages) are included,
        keyed by their dotted names; entries set to ``None`` are left out. The
        result is a shallow copy that refers to the module's own tensors.

        .. warning::
            ``destination``, ``prefix`` and ``keep_vars`` may still be passed
            positionally, in that order, but this is deprecated and will become
            keyword-only in a future release.

        .. warning::
            ``destination`` is not designed for end-users; avoid it.

        Args:
            destination (dict, optional): dict to fill in and return. When
                omitted, a new ``OrderedDict`` is created. Default: ``None``.
            prefix (str, optional): string prepended to every parameter and
                buffer name to form the keys. Default: ``''``.
            keep_vars (bool, optional): when ``False``, the stored
                :class:`~torch.Tensor` s are detached from autograd; when
                ``True``, no detaching is done. Default: ``False``.

        Returns:
            dict: the module's complete state.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> module.state_dict().keys()
            ['bias', 'weight']
        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        if args:
            # Legacy positional order is (destination, prefix, keep_vars); a
            # positional value only wins while its keyword is still at default.
            n_args = len(args)
            if destination is None:
                destination = args[0]
            if n_args > 1 and prefix == '':
                prefix = args[1]
            if n_args > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")

        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            # Children receive ``prefix + name + '.'``, so dropping the last
            # character yields this module's dotted path ('' for the root).
            destination._metadata[prefix[:-1]] = local_metadata

        for pre_hook in self._state_dict_pre_hooks.values():
            pre_hook(self, prefix, keep_vars)

        self._save_to_state_dict(destination, prefix, keep_vars)

        for child_name, child in self._modules.items():
            if child is None:
                continue
            child.state_dict(destination=destination, prefix=prefix + child_name + '.', keep_vars=keep_vars)

        # Post-hooks may hand back a replacement dict.
        for post_hook in self._state_dict_hooks.values():
            replacement = post_hook(self, destination, prefix, local_metadata)
            if replacement is not None:
                destination = replacement

        return destination

to

def to(
    self,
    *args,
    **kwargs
)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)

to(dtype, non_blocking=False)

to(tensor, non_blocking=False)

to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to, but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device torch.device the desired device of the parameters
and buffers in this module
None
dtype torch.dtype the desired floating point or complex dtype of
the parameters and buffers in this module
None
tensor torch.Tensor Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
None
memory_format torch.memory_format the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
None

Returns:

Type Description
Module self
View Source
    def to(self, *args, **kwargs):
        r"""Move and/or cast this module's parameters and buffers.

        Accepted call forms:

        * ``to(device=None, dtype=None, non_blocking=False)``
        * ``to(dtype, non_blocking=False)``
        * ``to(tensor, non_blocking=False)``
        * ``to(memory_format=torch.channels_last)``

        The signature mirrors :meth:`torch.Tensor.to`, except that only
        floating point or complex dtypes are accepted. Only floating point and
        complex parameters/buffers are cast to ``dtype``. Integral ones are
        moved to ``device`` (if given) but keep their dtype. With
        ``non_blocking`` set, the copy is attempted asynchronously with respect
        to the host where possible (e.g. pinned CPU memory to CUDA). A
        ``memory_format`` is only applied to 4-D and 5-D tensors.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): target device for the parameters
                and buffers.
            dtype (:class:`torch.dtype`): target floating point or complex
                dtype for the parameters and buffers.
            tensor (torch.Tensor): a Tensor whose dtype and device are used as
                the targets.
            memory_format (:class:`torch.memory_format`): target memory format
                (keyword only).

        Returns:
            Module: self (the result of ``_apply``).

        Raises:
            TypeError: if ``dtype`` is neither floating point nor complex.
            NotImplementedError: re-raised from tensor conversion; the
                meta-tensor case gets a hint to use ``to_empty()`` instead.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point and not dtype.is_complex:
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                f'dtypes, but got desired dtype={dtype}')
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected.")

        def _move(t):
            try:
                wants_format = convert_to_format is not None and t.dim() in (4, 5)
                # Integral tensors are never cast, only moved.
                cast_to = dtype if (t.is_floating_point() or t.is_complex()) else None
                if wants_format:
                    return t.to(device, cast_to, non_blocking, memory_format=convert_to_format)
                return t.to(device, cast_to, non_blocking)
            except NotImplementedError as e:
                if str(e) != "Cannot copy out of meta tensor; no data!":
                    raise
                raise NotImplementedError(
                    f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                    f"when moving module from meta to a different device."
                ) from None

        return self._apply(_move)

to_empty

def to_empty(
    self: ~T,
    *,
    device: Union[int, str, torch.device, NoneType],
    recurse: bool = True
) -> ~T

Move the parameters and buffers to the specified device without copying storage.

Parameters:

Name Type Description Default
device torch.device The desired device of the parameters
and buffers in this module.
None
recurse bool Whether parameters and buffers of submodules should
be recursively moved to the specified device.
True

Returns:

Type Description
Module self
View Source
    def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:

        r"""Move the parameters and buffers to the specified device without copying storage.

        Args:

            device (:class:`torch.device`): The desired device of the parameters

                and buffers in this module.

            recurse (bool): Whether parameters and buffers of submodules should

                be recursively moved to the specified device.

        Returns:

            Module: self

        """

        return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)

train

def train(
    self: ~T,
    mode: bool = True
) -> ~T

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. `Dropout`, `BatchNorm`, etc.

Parameters:

Name Type Description Default
mode bool Whether to set training mode (True) or evaluation
mode (False).
True

Returns:

Type Description
Module self
View Source
    def train(self: T, mode: bool = True) -> T:
        r"""Set the module in training mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                         mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        # Strict bool check: anything else is rejected instead of being stored
        # in ``self.training``.
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        # Only direct children are visited here; each child's own train() call
        # propagates the flag further down the tree.
        for module in self.children():
            module.train(mode)
        return self

type

def type(
    self: ~T,
    dst_type: Union[torch.dtype, str]
) -> ~T

Casts all parameters and buffers to `dst_type`.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
dst_type type or string the desired type None

Returns:

Type Description
Module self
View Source
    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Casts all parameters and buffers to :attr:`dst_type`.

        .. note::
            This method modifies the module in-place.

        Args:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        # Unlike float()/double()/half()/bfloat16(), there is no
        # is_floating_point() filter: every parameter and buffer is converted.
        return self._apply(lambda t: t.type(dst_type))

xpu

def xpu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the XPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing optimizer if the module will
        live on XPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        # ``device`` (including None) is forwarded unchanged to Tensor.xpu for
        # every parameter and buffer.
        return self._apply(lambda t: t.xpu(device))

zero_grad

def zero_grad(
    self,
    set_to_none: bool = True
) -> None

Reset gradients of all model parameters.

See the similar function under `torch.optim.Optimizer` for more context.

Parameters:

Name Type Description Default
set_to_none bool Instead of setting to zero, set the grads to None.
See `torch.optim.Optimizer.zero_grad` for details.
True
View Source
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Reset gradients of all model parameters.

        See similar function under :class:`torch.optim.Optimizer` for more context.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
        """
        # DataParallel replicas are flagged with ``_is_replica``; as the warning
        # text explains, their parameters are not leaves, so this call has no
        # useful effect. It warns but still proceeds.
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")
        for p in self.parameters():
            if p.grad is not None:
                if set_to_none:
                    # Drop the gradient tensor entirely.
                    p.grad = None
                else:
                    # Cut the grad out of any autograd graph (or stop it from
                    # requiring grad) before zeroing it in place.
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                        p.grad.requires_grad_(False)
                    p.grad.zero_()

WormPredictor

class WormPredictor(
    model: torch.nn.modules.module.Module,
    io_config: wtracker.neural.config.IOConfig
)

A class that represents neural network models that predict worm behavior. After a model is created from several layers or blocks, it is wrapped in this class so that it can be distinguished from other models that don't predict worm behavior (for example, the layers/blocks that make up this model). This class also holds the IOConfig object that is used to determine the input and output shapes of the model, and the specific frames it expects as input and output.

Attributes

Name Type Description Default
model torch.nn.Module The neural network model that predicts worm behavior. -
io_config IOConfig The IOConfig object of the model. -
View Source
class WormPredictor(nn.Module):
    """
    Wrapper that marks a neural network as a worm-behavior predictor.

    After a model is assembled from several layers or blocks, it is wrapped in
    this class so that it can be distinguished from models that do not predict
    worm behavior (for example, the layers/blocks the predictor is built from).
    The wrapper also holds the IOConfig object that determines the input and
    output shapes of the model, and the specific frames it expects as input
    and output.

    Attributes:
        model: The neural network model that predicts worm behavior.
        io_config: The IOConfig object of the model.
    """

    def __init__(self, model: nn.Module, io_config: IOConfig):
        """
        Wrap ``model`` together with its input/output configuration.

        Args:
            model: The network that performs the prediction.
            io_config: The IOConfig describing the model's inputs and outputs.
        """
        super().__init__()
        self.io_config: IOConfig = io_config
        # Assigning an nn.Module attribute registers it as a submodule, so the
        # wrapped model's parameters are exposed through this wrapper.
        self.model: nn.Module = model

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.model(x)

Ancestors (in MRO)

  • torch.nn.modules.module.Module

Class variables

T_destination
call_super_init
dump_patches

Methods

add_module

def add_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Parameters:

Name Type Description Default
name str Name of the child module. The child module can be
accessed from this module using the given name.
required
module Module Child module to be added to the module (may be None). required
View Source
    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Add a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        # Validation: ``module`` must be a Module or None, and ``name`` must be
        # a non-empty str.
        if not isinstance(module, Module) and module is not None:
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        elif not isinstance(name, str):
            raise TypeError(f"module name should be a string. Got {torch.typename(name)}")
        # Refuse to shadow an existing non-module attribute; re-registering an
        # existing child name is allowed and simply replaces it.
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        # Dots are reserved as path separators (get_submodule splits targets on ".").
        elif '.' in name:
            raise KeyError(f"module name can't contain \".\", got: {name}")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        # Global registration hooks may substitute the module: any non-None
        # return value replaces what gets stored.
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

apply

def apply(
    self: ~T,
    fn: Callable[[ForwardRef('Module')], NoneType]
) -> ~T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also the `torch.nn.init` documentation).

Parameters:

Name Type Description Default
fn Callable[[Module], None] Function to be applied to each submodule. required

Returns:

Type Description
Module self
View Source
    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

        Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )

        """
        # Post-order traversal: every child subtree is processed first, then
        # ``fn`` is called on this module (matching the example output order).
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

bfloat16

def bfloat16(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to bfloat16 datatype.

Note: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def bfloat16(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors (e.g. integer buffers) pass through unchanged.
        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

buffers

def buffers(
    self,
    recurse: bool = True
) -> Iterator[torch.Tensor]

Return an iterator over module buffers.

Parameters:

Name Type Description Default
recurse bool If True, yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
True

Yields:

Type Description
torch.Tensor module buffer
View Source
    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
        r"""Return an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for buf in model.buffers():
            >>>     print(type(buf), buf.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
        # Thin generator over named_buffers() that drops the names.
        for _, buf in self.named_buffers(recurse=recurse):
            yield buf

children

def children(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over immediate children modules.

Yields:

Type Description
Module a child module
View Source
    def children(self) -> Iterator['Module']:
        r"""Return an iterator over immediate children modules.

        Yields:
            Module: a child module
        """
        # Thin generator over named_children() that drops the names.
        for name, module in self.named_children():
            yield module

compile

def compile(
    self,
    *args,
    **kwargs
)

Compile this Module's forward using `torch.compile`.

This Module's `__call__` method is compiled and all arguments are passed as-is to `torch.compile`.

See `torch.compile` for details on the arguments for this function.

View Source
    def compile(self, *args, **kwargs):
        """
        Compile this Module's forward using :func:`torch.compile`.

        This Module's `__call__` method is compiled and all arguments are passed as-is
        to :func:`torch.compile`.

        See :func:`torch.compile` for details on the arguments for this function.
        """
        # Compiles the bound _call_impl and stores the result on the instance;
        # nothing is returned. NOTE(review): presumably __call__ dispatches to
        # _compiled_call_impl when it is set — not visible in this listing.
        self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)

cpu

def cpu(
    self: ~T
) -> ~T

Move all model parameters and buffers to the CPU.

Note: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def cpu(self: T) -> T:
        r"""Move all model parameters and buffers to the CPU.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Applies Tensor.cpu to every parameter and buffer via _apply.
        return self._apply(lambda t: t.cpu())

cuda

def cuda(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing optimizer if the module will
        live on GPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        # ``device`` (including None) is forwarded unchanged to Tensor.cuda for
        # every parameter and buffer.
        return self._apply(lambda t: t.cuda(device))

double

def double(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to double datatype.

Note: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def double(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``double`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors (e.g. integer buffers) pass through unchanged.
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

eval

def eval(
    self: ~T
) -> ~T

Set the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. `Dropout`, `BatchNorm`, etc.

This is equivalent to `self.train(False)` (see `torch.nn.Module.train`).

See the "Locally disabling gradient computation" notes (`locally-disable-grad-doc`) for a comparison between `.eval()` and several similar mechanisms that may be confused with it.

Returns:

Type Description
Module self
View Source
    def eval(self: T) -> T:
        r"""Set the module in evaluation mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

        See :ref:`locally-disable-grad-doc` for a comparison between
        `.eval()` and several similar mechanisms that may be confused with it.

        Returns:
            Module: self
        """
        # Pure delegation: evaluation mode is just training=False, propagated
        # recursively by train().
        return self.train(False)

extra_repr

def extra_repr(
    self
) -> str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

View Source
    def extra_repr(self) -> str:
        r"""Set the extra representation of the module.

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        # The base class contributes no extra text; subclasses override this.
        return ''

float

def float(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to float datatype.

Note: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def float(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``float`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors (e.g. integer buffers) pass through unchanged.
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

forward

def forward(
    self,
    x: torch.Tensor
) -> torch.Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

View Source
    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.model(x)

get_buffer

def get_buffer(
    self,
    target: str
) -> 'Tensor'

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target str The fully-qualified string name of the buffer
to look for. (See get_submodule for how to specify a
fully-qualified string.)
required

Returns:

Type Description
torch.Tensor The buffer referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not a
buffer
View Source
    def get_buffer(self, target: str) -> "Tensor":
        """Return the buffer given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the buffer
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not a
                buffer
        """
        # Split "a.b.buf" into ("a.b", "buf"). A name without dots yields an
        # empty module path, which get_submodule resolves to ``self``.
        module_path, _, buffer_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        # Existence is checked before buffer membership so the two failure
        # modes produce distinct error messages.
        if not hasattr(mod, buffer_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + buffer_name + "`")
        buffer: torch.Tensor = getattr(mod, buffer_name)
        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")
        return buffer

get_extra_state

def get_extra_state(
    self
) -> Any

Return any extra state to include in the module's state_dict.

Implement this and a corresponding `set_extra_state` for your module if you need to store extra state. This function is called when building the module's state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Type Description
object Any extra state to store in the module's state_dict
View Source
    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Implement this and a corresponding :func:`set_extra_state` for your module
        if you need to store extra state. This function is called when building the
        module's `state_dict()`.

        Note that extra state should be picklable to ensure working serialization
        of the state_dict. We only provide backwards compatibility guarantees
        for serializing Tensors; other objects may break backwards compatibility if
        their serialized pickled form changes.

        Returns:
            object: Any extra state to store in the module's state_dict
        """
        # The base implementation is never supposed to run. Subclasses that
        # need extra state override it. NOTE(review): the caller presumably
        # checks for an override first — that check is not visible here.
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug.")

get_parameter

def get_parameter(
    self,
    target: str
) -> 'Parameter'

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.

Parameters:

Name Type Description Default
target str The fully-qualified string name of the Parameter
to look for. (See get_submodule for how to specify a
fully-qualified string.)
required

Returns:

Type Description
torch.nn.Parameter The Parameter referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Parameter
View Source
    def get_parameter(self, target: str) -> "Parameter":
        """Return the parameter given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the Parameter
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Parameter``
        """
        # Split "a.b.weight" into ("a.b", "weight"). A name without dots yields
        # an empty module path, which get_submodule resolves to ``self``.
        module_path, _, param_name = target.rpartition(".")
        mod: torch.nn.Module = self.get_submodule(module_path)
        # Existence is checked before the type check so the two failure modes
        # produce distinct error messages.
        if not hasattr(mod, param_name):
            raise AttributeError(mod._get_name() + " has no attribute `"
                                 + param_name + "`")
        param: torch.nn.Parameter = getattr(mod, param_name)
        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an "
                                 "nn.Parameter")
        return param

get_submodule

def get_submodule(
    self,
    target: str
) -> 'Module'

Return the submodule given by target if it exists, otherwise throw an error.

For example, let's say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

Name Type Description Default
target str The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
required

Returns:

Type Description
torch.nn.Module The submodule referenced by target

Raises:

Type Description
AttributeError If the target string references an invalid
path or resolves to something that is not an
nn.Module
View Source
    def get_submodule(self, target: str) -> "Module":
        """Return the submodule given by ``target`` if it exists, otherwise throw an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.

        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        # The empty path names this module itself; get_buffer/get_parameter
        # rely on this for top-level (dot-free) targets.
        if target == "":
            return self
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
        # One attribute hop per dotted component, so the work is proportional
        # to the nesting depth of ``target``.
        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no "
                                     "attribute `" + item + "`")
            mod = getattr(mod, item)
            # Every intermediate and final hop must land on a Module.
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not "
                                     "an nn.Module")
        return mod

half

def half(
    self: ~T
) -> ~T

Casts all floating point parameters and buffers to half datatype.

Note: This method modifies the module in-place.

Returns:

Type Description
Module self
View Source
    def half(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``half`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        # Non-floating-point tensors (e.g. integer buffers) pass through unchanged.
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

ipu

def ipu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int if specified, all parameters will be
copied to that device
None

Returns:

Type Description
Module self
View Source
    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the IPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing optimizer if the module will
        live on IPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        # ``device`` (including None) is forwarded unchanged to Tensor.ipu for
        # every parameter and buffer.
        return self._apply(lambda t: t.ipu(device))

load_state_dict

def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    assign: bool = False
)

Copy parameters and buffers from :attr:state_dict into this module and its descendants.

If :attr:strict is True, then the keys of :attr:state_dict must exactly match the keys returned by this module's :meth:~torch.nn.Module.state_dict function.

.. warning:: If :attr:assign is True the optimizer must be created after the call to :attr:load_state_dict unless :func:~torch.__future__.get_swap_module_params_on_conversion is True.

Parameters:

Name Type Description Default
state_dict dict a dict containing parameters and
persistent buffers.
required
strict bool whether to strictly enforce that the keys
in :attr:state_dict match the keys returned by this module's
:meth:~torch.nn.Module.state_dict function. Default: True
True
assign bool When False, the properties of the tensors
in the current module are preserved while when True, the
properties of the Tensors in the state dict are preserved. The only
exception is the requires_grad field of :class:~torch.nn.Parameter objects,
for which the value from the module is preserved.
Default: False
False

Returns:

Type Description
NamedTuple A NamedTuple with missing_keys and unexpected_keys fields:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
View Source
    def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True, assign: bool = False):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

        If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict` unless
            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): When ``False``, the properties of the tensors
                in the current module are preserved while when ``True``, the
                properties of the Tensors in the state dict are preserved. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`
                objects, for which the value from the module is preserved.
                Default: ``False``

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys

        Raises:
            TypeError: if ``state_dict`` is not a ``Mapping``.
            RuntimeError: if any error message was collected while loading,
                including missing / unexpected keys when ``strict`` is ``True``.
            AssertionError: if a hook registered with
                ``register_load_state_dict_post_hook`` returns something other
                than ``None``.

        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

        # These lists are shared by every recursive ``load`` call and every post-hook.
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]

        def load(module, local_state_dict, prefix=''):
            # Metadata is looked up by the prefix without its trailing '.'.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata['assign_to_params_buffers'] = assign
            module._load_from_state_dict(
                local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + '.'
                    # Hand each child only the entries under its own prefix.
                    child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                    load(child, child_state_dict, child_prefix)  # noqa: F821

            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                # Adjacent string literals are concatenated verbatim, so every
                # piece except the last needs its own trailing space.
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

        load(self, state_dict)
        # ``load`` refers to itself recursively; deleting the name presumably
        # breaks that closure self-reference so it can be freed promptly.
        del load

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
        return _IncompatibleKeys(missing_keys, unexpected_keys)

modules

def modules(
    self
) -> Iterator[ForwardRef('Module')]

Return an iterator over all modules in the network.

Yields:

Type Description
Module a module in the network
View Source
    def modules(self) -> Iterator['Module']:
        r"""Iterate over every module in the network, starting with this one.

        Yields:
            Module: a module in the network

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.modules()):
            ...     print(idx, '->', m)

            0 -> Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )
            1 -> Linear(in_features=2, out_features=2, bias=True)

        """
        # named_modules() does the traversal and de-duplication; only the names are dropped here.
        yield from (module for _name, module in self.named_modules())

named_buffers

def named_buffers(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all buffer names. ''
recurse bool if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
True
remove_duplicate bool whether to remove the duplicated buffers in the result. Defaults to True. True

Yields:

Type Description
None (str, torch.Tensor): Tuple containing the name and buffer
View Source
    def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
        r"""Iterate over module buffers, yielding each buffer together with its name.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool, optional): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module. Defaults to True.
            remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

        Yields:
            (str, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, buf in self.named_buffers():
            ...     if name in ['running_var']:
            ...         print(buf.size())

        """

        def _own_buffers(module):
            # Buffers registered directly on ``module``; ``_named_members``
            # receives ``recurse``/``prefix`` and decides which modules to visit.
            return module._buffers.items()

        yield from self._named_members(
            _own_buffers,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

named_children

def named_children(
    self
) -> Iterator[Tuple[str, ForwardRef('Module')]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

Type Description
None (str, Module): Tuple containing a name and child module
View Source
    def named_children(self) -> Iterator[Tuple[str, 'Module']]:
        r"""Iterate over the immediate child modules, yielding each with its name.

        Empty (``None``) slots are skipped, and a module registered under
        several names is yielded only once, under the first name seen.

        Yields:
            (str, Module): Tuple containing a name and child module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, module in model.named_children():
            ...     if name in ['conv4', 'conv5']:
            ...         print(module)

        """
        already_yielded = set()
        for child_name, child in self._modules.items():
            if child is None or child in already_yielded:
                continue
            already_yielded.add(child)
            yield child_name, child

named_modules

def named_modules(
    self,
    memo: Optional[Set[ForwardRef('Module')]] = None,
    prefix: str = '',
    remove_duplicate: bool = True
)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Parameters:

Name Type Description Default
memo Optional[Set[Module]] a memo to store the set of modules already added to the result None
prefix str a prefix that will be added to the name of the module ''
remove_duplicate bool whether to remove the duplicated module instances in the result
or not
True

Yields:

Type Description
None (str, Module): Tuple of name and module
View Source
    def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
        r"""Walk the module tree depth-first, yielding ``(name, module)`` pairs.

        This module itself is yielded first under ``prefix``. Children follow,
        with names joined to the prefix by ``'.'``.

        Args:
            memo: a memo to store the set of modules already added to the result
            prefix: a prefix that will be added to the name of the module
            remove_duplicate: whether to remove the duplicated module instances in the result
                or not

        Yields:
            (str, Module): Tuple of name and module

        Note:
            Duplicate modules are returned only once. In the following
            example, ``l`` will be returned only once.

        Example::

            >>> l = nn.Linear(2, 2)
            >>> net = nn.Sequential(l, l)
            >>> for idx, m in enumerate(net.named_modules()):
            ...     print(idx, '->', m)

            0 -> ('', Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            ))
            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

        """
        # The same set object is threaded through the whole recursion.
        visited = set() if memo is None else memo
        if self in visited:
            return
        if remove_duplicate:
            visited.add(self)
        yield prefix, self
        for child_name, child in self._modules.items():
            if child is None:
                continue
            child_prefix = prefix + '.' + child_name if prefix else child_name
            yield from child.named_modules(visited, child_prefix, remove_duplicate)

named_parameters

def named_parameters(
    self,
    prefix: str = '',
    recurse: bool = True,
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:

Name Type Description Default
prefix str prefix to prepend to all parameter names. ''
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True
remove_duplicate bool whether to remove the duplicated
parameters in the result. Defaults to True.
True

Yields:

Type Description
None (str, Parameter): Tuple containing the name and parameter
View Source
    def named_parameters(
            self,
            prefix: str = '',
            recurse: bool = True,
            remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        r"""Iterate over module parameters, yielding each parameter together with its name.

        Args:
            prefix (str): prefix to prepend to all parameter names.
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.
            remove_duplicate (bool, optional): whether to remove the duplicated
                parameters in the result. Defaults to True.

        Yields:
            (str, Parameter): Tuple containing the name and parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for name, param in self.named_parameters():
            ...     if name in ['bias']:
            ...         print(param.size())

        """

        def _own_parameters(module):
            # Parameters registered directly on ``module``; ``_named_members``
            # receives ``recurse``/``prefix`` and decides which modules to visit.
            return module._parameters.items()

        yield from self._named_members(
            _own_parameters,
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        )

parameters

def parameters(
    self,
    recurse: bool = True
) -> Iterator[torch.nn.parameter.Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Parameters:

Name Type Description Default
recurse bool if True, then yields parameters of this module
and all submodules. Otherwise, yields only parameters that
are direct members of this module.
True

Yields:

Type Description
Parameter module parameter
View Source
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        r"""Iterate over module parameters (names omitted).

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for param in model.parameters():
            ...     print(type(param), param.size())
            <class 'torch.Tensor'> (20L,)
            <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

        """
        # named_parameters() does all the work; only the names are dropped here.
        yield from (param for _name, param in self.named_parameters(recurse=recurse))

register_backward_hook

def register_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of :meth:~torch.nn.Module.register_full_backward_hook and the behavior of this function will change in future versions.

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_backward_hook(
        self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Register a (legacy) backward hook on the module.

        This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions. Regular and
        full backward hooks cannot be mixed on the same module.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        Raises:
            RuntimeError: if a full backward hook was already registered on this module.
        """
        mixing_hook_kinds = self._is_full_backward_hook is True
        if mixing_hook_kinds:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")

        # Record that this module now uses regular (non-full) backward hooks.
        self._is_full_backward_hook = False

        registry = self._backward_hooks
        handle = hooks.RemovableHandle(registry)
        registry[handle.id] = hook
        return handle

register_buffer

def register_buffer(
    self,
    name: str,
    tensor: Optional[torch.Tensor],
    persistent: bool = True
) -> None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:state_dict.

Buffers can be accessed as attributes using given names.

Parameters:

Name Type Description Default
name str name of the buffer. The buffer can be accessed
from this module using the given name
required
tensor Tensor or None buffer to be registered. If None, then operations
that run on buffers, such as :attr:cuda, are ignored. If None,
the buffer is not included in the module's :attr:state_dict.
required
persistent bool whether the buffer is part of this module's
:attr:state_dict.
True
View Source
    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Add a buffer to the module.

        Buffers hold state that should not be considered a model parameter.
        For example, BatchNorm's ``running_mean`` is not a parameter, but is
        part of the module's state. Buffers are persistent by default and are
        saved alongside parameters. Pass ``persistent=False`` to keep a buffer
        out of this module's :attr:`state_dict`; that is the only difference
        between persistent and non-persistent buffers.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Raises:
            RuntimeError: non-persistent buffer requested on a ``ScriptModule``.
            AttributeError: called before ``Module.__init__()``.
            TypeError: ``name`` is not a string, or ``tensor`` is neither a Tensor nor ``None``.
            KeyError: ``name`` contains ``"."``, is empty, or clashes with an
                existing non-buffer attribute.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        # Validation: each check raises immediately, in this exact order.
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")
        if '_buffers' not in self.__dict__:
            raise AttributeError("cannot assign buffer before Module.__init__() call")
        if not isinstance(name, str):
            raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}")
        if '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        if name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                            "(torch Tensor or None required)")

        # Globally registered hooks may substitute the tensor; a None result keeps the current one.
        for registration_hook in _global_buffer_registration_hooks.values():
            replacement = registration_hook(self, name, tensor)
            if replacement is not None:
                tensor = replacement

        self._buffers[name] = tensor
        non_persistent = self._non_persistent_buffers_set
        if persistent:
            non_persistent.discard(name)
        else:
            non_persistent.add(name)

register_forward_hook

def register_forward_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False,
    always_call: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward hook on the module.

The hook will be called every time after :func:forward has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:forward is called. The hook should have the following signature::

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::

hook(module, args, kwargs, output) -> None or modified output

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If True, the provided hook will be fired
before all existing forward hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing forward hooks on
this :class:torch.nn.modules.Module. Note that global
forward hooks registered with
:func:register_module_forward_hook will fire before all hooks
registered by this method.
Default: False
False
with_kwargs bool If True, the hook will be passed the
kwargs given to the forward function.
Default: False
False
always_call bool If True the hook will be run regardless of
whether an exception is raised while calling the Module.
Default: False
False

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_hook(

        self,

        hook: Union[

            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],

        ],

        *,

        prepend: bool = False,

        with_kwargs: bool = False,

        always_call: bool = False,

    ) -> RemovableHandle:

        r"""Register a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.

        If ``with_kwargs`` is ``False`` or not specified, the input contains only

        the positional arguments given to the module. Keyword arguments won't be

        passed to the hooks and only to the ``forward``. The hook can modify the

        output. It can modify the input inplace but it will not have effect on

        forward since this is called after :func:`forward` is called. The hook

        should have the following signature::

            hook(module, args, output) -> None or modified output

        If ``with_kwargs`` is ``True``, the forward hook will be passed the

        ``kwargs`` given to the forward function and be expected to return the

        output possibly modified. The hook should have the following signature::

            hook(module, args, kwargs, output) -> None or modified output

        Args:

            hook (Callable): The user defined hook to be registered.

            prepend (bool): If ``True``, the provided ``hook`` will be fired

                before all existing ``forward`` hooks on this

                :class:`torch.nn.modules.Module`. Otherwise, the provided

                ``hook`` will be fired after all existing ``forward`` hooks on

                this :class:`torch.nn.modules.Module`. Note that global

                ``forward`` hooks registered with

                :func:`register_module_forward_hook` will fire before all hooks

                registered by this method.

                Default: ``False``

            with_kwargs (bool): If ``True``, the ``hook`` will be passed the

                kwargs given to the forward function.

                Default: ``False``

            always_call (bool): If ``True`` the ``hook`` will be run regardless of

                whether an exception is raised while calling the Module.

                Default: ``False``

        Returns:

            :class:`torch.utils.hooks.RemovableHandle`:

                a handle that can be used to remove the added hook by calling

                ``handle.remove()``

        """

        handle = hooks.RemovableHandle(

            self._forward_hooks,

            extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called],

        )

        self._forward_hooks[handle.id] = hook

        if with_kwargs:

            self._forward_hooks_with_kwargs[handle.id] = True

        if always_call:

            self._forward_hooks_always_called[handle.id] = True

        if prepend:

            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]

        return handle

register_forward_pre_hook

def register_forward_pre_hook(
    self,
    hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]],
    *,
    prepend: bool = False,
    with_kwargs: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before :func:forward is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Parameters:

Name Type Description Default
hook Callable The user defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing forward_pre hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing forward_pre hooks
on this :class:torch.nn.modules.Module. Note that global
forward_pre hooks registered with
:func:register_module_forward_pre_hook will fire before all
hooks registered by this method.
Default: False
False
with_kwargs bool If true, the hook will be passed the kwargs
given to the forward function.
Default: False
False

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Register a hook that runs every time before :func:`forward` is invoked.

        By default (``with_kwargs=False``) the hook receives only the positional
        arguments given to the module; keyword arguments go to ``forward`` only.
        The hook may modify the input by returning either a tuple or a single
        value. A single value that is not already a tuple is wrapped into one.
        Signature::

            hook(module, args) -> None or modified input

        With ``with_kwargs=True`` the hook also receives the kwargs given to
        the forward function. If it modifies the input, it must return both the
        args and the kwargs. Signature::

            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If true, the provided ``hook`` will be fired before
                all existing ``forward_pre`` hooks on this
                :class:`torch.nn.modules.Module`. Otherwise, the provided
                ``hook`` will be fired after all existing ``forward_pre`` hooks
                on this :class:`torch.nn.modules.Module`. Note that global
                ``forward_pre`` hooks registered with
                :func:`register_module_forward_pre_hook` will fire before all
                hooks registered by this method.
                Default: ``False``
            with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
                given to the forward function.
                Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        pre_hooks = self._forward_pre_hooks
        # The kwargs-flag dictionary is the handle's extra_dict, so remove() clears it too.
        handle = hooks.RemovableHandle(
            pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs
        )
        pre_hooks[handle.id] = hook
        if with_kwargs:
            self._forward_pre_hooks_with_kwargs[handle.id] = True
        if prepend:
            # Hooks fire in dict order, so moving the entry to the front runs it first.
            pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

register_full_backward_hook

def register_full_backward_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The :attr:grad_input and :attr:grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:grad_input in subsequent computations. :attr:grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:grad_input and :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward hooks on
this :class:torch.nn.modules.Module. Note that global
backward hooks registered with
:func:register_module_full_backward_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a *full* backward hook to this module.

        The hook runs whenever gradients with respect to the module's outputs
        are computed (and only then), and is called as::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        ``grad_input`` and ``grad_output`` are tuples holding the gradients with
        respect to the module's inputs and outputs. Only positional inputs are
        represented in ``grad_input``; keyword arguments are ignored. Positions
        that correspond to non-Tensor arguments hold ``None``. The hook must not
        mutate its arguments. It may instead return a replacement for
        ``grad_input``, which is then used in the rest of the computation.

        Because of how this mechanism is implemented, ``forward`` receives views
        of the Tensors passed to the module, and the caller receives views of
        the Tensors that ``forward`` returns.

        .. warning ::
            Modifying inputs or outputs inplace is not allowed when using
            backward hooks and will raise an error.

        Args:
            hook (Callable): the user-defined hook to register.
            prepend (bool): when ``True``, the hook is placed ahead of every
                ``backward`` hook already registered on this
                :class:`torch.nn.modules.Module`. Otherwise it is appended after
                them. Hooks registered globally through
                :func:`register_module_full_backward_hook` still fire before any
                hook registered here. Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.

        Raises:
            RuntimeError: if ``self._is_full_backward_hook`` is ``False``, i.e.
                the module is already flagged as using regular (non-full)
                backward hooks. The two kinds cannot be mixed.
        """
        # A ``False`` flag marks regular backward hooks; per the error message
        # below, they cannot coexist with full backward hooks.
        if self._is_full_backward_hook is False:
            raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                               "single Module. Please use only one of them.")
        self._is_full_backward_hook = True

        registry = self._backward_hooks
        removable = hooks.RemovableHandle(registry)
        registry[removable.id] = hook
        if prepend:
            # The new entry was appended last; move it to the front so it fires first.
            registry.move_to_end(removable.id, last=False)  # type: ignore[attr-defined]
        return removable

register_full_backward_pre_hook

def register_full_backward_pre_hook(
    self,
    hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]],
    prepend: bool = False
) -> torch.utils.hooks.RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature::

hook(module, grad_output) -> tuple[Tensor] or None

The :attr:grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:grad_output in subsequent computations. Entries in :attr:grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.

.. warning :: Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Parameters:

Name Type Description Default
hook Callable The user-defined hook to be registered. required
prepend bool If true, the provided hook will be fired before
all existing backward_pre hooks on this
:class:torch.nn.modules.Module. Otherwise, the provided
hook will be fired after all existing backward_pre hooks
on this :class:torch.nn.modules.Module. Note that global
backward_pre hooks registered with
:func:register_module_full_backward_pre_hook will fire before
all hooks registered by this method.
False

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Attach a backward pre-hook to this module.

        The hook is invoked every time the gradients for the module are
        computed, and is called as::

            hook(module, grad_output) -> tuple[Tensor] or None

        ``grad_output`` is a tuple. Entries that correspond to non-Tensor
        arguments are ``None``. The hook must not mutate its arguments. It may
        instead return a replacement gradient with respect to the output, which
        is then used in place of ``grad_output`` for the rest of the
        computation.

        Because of how this mechanism is implemented, ``forward`` receives views
        of the Tensors passed to the module, and the caller receives views of
        the Tensors that ``forward`` returns.

        .. warning ::
            Modifying inputs inplace is not allowed when using backward hooks
            and will raise an error.

        Args:
            hook (Callable): the user-defined hook to register.
            prepend (bool): when ``True``, the hook is placed ahead of every
                ``backward_pre`` hook already registered on this
                :class:`torch.nn.modules.Module`. Otherwise it is appended after
                them. Hooks registered globally through
                :func:`register_module_full_backward_pre_hook` still fire before
                any hook registered here. Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._backward_pre_hooks
        removable = hooks.RemovableHandle(pre_hooks)
        pre_hooks[removable.id] = hook
        if not prepend:
            return removable
        # The new entry was appended last; move it to the front so it fires first.
        pre_hooks.move_to_end(removable.id, last=False)  # type: ignore[attr-defined]
        return removable

register_load_state_dict_post_hook

def register_load_state_dict_post_hook(
    self,
    hook
)

Register a post hook to be run after module's load_state_dict is called.

It should have the following signature:: hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling :func:load_state_dict with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:

Type Description
RemovableHandle :class:torch.utils.hooks.RemovableHandle:
a handle that can be used to remove the added hook by calling
handle.remove()
View Source
    def register_load_state_dict_post_hook(self, hook):
        r"""Register ``hook`` to run after this module's ``load_state_dict`` call.

        Expected signature::

            hook(module, incompatible_keys) -> None

        ``module`` is the module the hook was registered on.
        ``incompatible_keys`` is a ``NamedTuple`` with the attributes
        ``missing_keys`` and ``unexpected_keys``. Each is a ``list`` of ``str``,
        and the hook may edit these lists in place.

        The ``strict=True`` checks of :func:`load_state_dict` take such edits
        into account. Adding keys to either list produces an error under
        ``strict=True``, while clearing both lists avoids one.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.
        """
        registry = self._load_state_dict_post_hooks
        removable = hooks.RemovableHandle(registry)
        registry[removable.id] = hook
        return removable

register_module

def register_module(
    self,
    name: str,
    module: Optional[ForwardRef('Module')]
) -> None

Alias for :func:add_module.

View Source
    def register_module(self, name: str, module: Optional['Module']) -> None:
        r"""Register ``module`` as a child of this module under ``name``.

        This is an alias for :func:`add_module`. All of the work happens in that
        method; anything it returns is discarded here.
        """
        delegate = self.add_module
        delegate(name, module)
        return None

register_parameter

def register_parameter(
    self,
    name: str,
    param: Optional[torch.nn.parameter.Parameter]
) -> None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Parameters:

Name Type Description Default
name str name of the parameter. The parameter can be accessed
from this module using the given name
required
param Parameter or None parameter to be added to the module. If
None, then operations that run on parameters, such as :attr:cuda,
are ignored. If None, the parameter is not included in the
module's :attr:state_dict.
required
View Source
    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Args:
            name (str): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter or None): parameter to be added to the module. If
                ``None``, then operations that run on parameters, such as :attr:`cuda`,
                are ignored. If ``None``, the parameter is **not** included in the
                module's :attr:`state_dict`.

        Raises:
            AttributeError: if called before ``Module.__init__()`` has created
                ``self._parameters``.
            TypeError: if ``name`` is not a ``str``, or if ``param`` is neither
                ``None`` nor a ``Parameter``.
            KeyError: if ``name`` contains ``"."``, is the empty string, or names
                an existing attribute that is not already a registered parameter.
            ValueError: if ``param`` is a non-leaf tensor (it has a ``grad_fn``).
        """
        # The name checks run in this order, so the first failing check decides
        # which exception the caller sees.
        # ``_parameters`` lives in the instance ``__dict__`` once
        # ``Module.__init__()`` has run; without it there is nowhere to store
        # the parameter.
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        elif not isinstance(name, str):
            raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}")
        # NOTE(review): dots are presumably rejected because dotted names act as
        # path separators (state_dict builds keys as ``prefix + name + '.'``).
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        # Re-registering an existing parameter name is allowed. Shadowing any
        # other existing attribute is not.
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        if param is None:
            # Keep an explicit ``None`` slot under this name.
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                            "(torch.nn.Parameter or None required)"
                            )
        # A tensor with a ``grad_fn`` was produced by an autograd operation, i.e.
        # it is not a leaf, and cannot be a parameter.
        elif param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.")
        else:
            # Each global registration hook may substitute the parameter by
            # returning a non-None value. Later hooks see the substituted value,
            # and the final value is what gets stored.
            for hook in _global_parameter_registration_hooks.values():
                output = hook(self, name, param)
                if output is not None:
                    param = output
            self._parameters[name] = param

register_state_dict_pre_hook

def register_state_dict_pre_hook(
    self,
    hook
)

Register a pre-hook for the :meth:~torch.nn.Module.state_dict method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

View Source
    def register_state_dict_pre_hook(self, hook):
        r"""Register ``hook`` to run before :meth:`~torch.nn.Module.state_dict`.

        The hook is called as ``hook(self, prefix, keep_vars)`` before
        ``state_dict`` is computed for ``self``. Use it for any preprocessing
        that the ``state_dict`` call needs.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle whose ``remove()`` method unregisters the hook.
        """
        pre_hooks = self._state_dict_pre_hooks
        removable = hooks.RemovableHandle(pre_hooks)
        pre_hooks[removable.id] = hook
        return removable

requires_grad_

def requires_grad_(
    self: ~T,
    requires_grad: bool = True
) -> ~T

Change if autograd should record operations on parameters in this module.

This method sets the parameters' :attr:requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See :ref:locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Parameters:

Name Type Description Default
requires_grad bool whether autograd should record operations on
parameters in this module. Default: True.
True

Returns:

Type Description
Module self
View Source
    def requires_grad_(self: T, requires_grad: bool = True) -> T:
        r"""Set ``requires_grad`` in place on every parameter of this module.

        This decides whether autograd records operations on the parameters.
        It is handy for freezing part of a model during fine-tuning, or for
        training sub-models separately (e.g., GAN training). See
        :ref:`locally-disable-grad-doc` for how `.requires_grad_()` compares
        with several similar mechanisms that are easily confused with it.

        Args:
            requires_grad (bool): whether autograd should record operations on
                this module's parameters. Default: ``True``.

        Returns:
            Module: self
        """
        flag = requires_grad
        params = self.parameters()
        for tensor in params:
            tensor.requires_grad_(flag)
        return self

set_extra_state

def set_extra_state(
    self,
    state: Any
) -> None

Set extra state contained in the loaded state_dict.

This function is called from :func:load_state_dict to handle any extra state found within the state_dict. Implement this function and a corresponding :func:get_extra_state for your module if you need to store extra state within its state_dict.

View Source
    def set_extra_state(self, state: Any) -> None:
        """Restore extra state read from a loaded `state_dict`.

        :func:`load_state_dict` calls this for any extra state it finds in the
        `state_dict`. A module that stores extra state should override this
        method together with a matching :func:`get_extra_state`.

        The base implementation is not meant to be reached: it always raises.

        Args:
            state (dict): Extra state from the `state_dict`

        Raises:
            RuntimeError: always, in this base implementation.
        """
        message = (
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )
        raise RuntimeError(message)

share_memory

def share_memory(
    self: ~T
) -> ~T

See :meth:torch.Tensor.share_memory_.

View Source
    def share_memory(self: T) -> T:
        r"""Apply :meth:`torch.Tensor.share_memory_` to the module's tensors.

        The conversion is delegated to ``self._apply``, and its result is
        returned.
        """
        def _to_shared(tensor):
            return tensor.share_memory_()

        return self._apply(_to_shared)

state_dict

def state_dict(
    self,
    *args,
    destination=None,
    prefix='',
    keep_vars=False
)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

.. note:: The returned object is a shallow copy. It contains references to the module's parameters and buffers.

.. warning:: Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

.. warning:: Please avoid the use of argument destination as it is not designed for end-users.

Parameters:

Name Type Description Default
destination dict If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict will be created and returned.
Default: None.
None
prefix str a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''.
''
keep_vars bool by default the :class:~torch.Tensor s
returned in the state dict are detached from autograd. If it's
set to True, detaching will not be performed.
Default: False.
False

Returns:

Type Description
dict a dictionary containing a whole state of the module
View Source
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        r"""Return a dictionary containing references to the whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included. Keys are corresponding parameter and buffer names.
        Parameters and buffers set to ``None`` are not included.

        Hooks stored in ``self._state_dict_pre_hooks`` are called as
        ``hook(self, prefix, keep_vars)`` before anything is saved. Hooks in
        ``self._state_dict_hooks`` run after this module and its children have
        been saved. Such a hook may replace the returned mapping by returning a
        non-``None`` value.

        .. note::
            The returned object is a shallow copy. It contains references
            to the module's parameters and buffers.

        .. warning::
            Currently ``state_dict()`` also accepts positional arguments for
            ``destination``, ``prefix`` and ``keep_vars`` in order. However,
            this is being deprecated and keyword arguments will be enforced in
            future releases.

        .. warning::
            Please avoid the use of argument ``destination`` as it is not
            designed for end-users.

        Args:
            destination (dict, optional): If provided, the state of module will
                be updated into the dict and the same object is returned.
                Otherwise, an ``OrderedDict`` will be created and returned.
                Default: ``None``.
            prefix (str, optional): a prefix added to parameter and buffer
                names to compose the keys in state_dict. Default: ``''``.
            keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
                returned in the state dict are detached from autograd. If it's
                set to ``True``, detaching will not be performed.
                Default: ``False``.

        Returns:
            dict:
                a dictionary containing a whole state of the module

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> module.state_dict().keys()
            ['bias', 'weight']

        """
        # TODO: Remove `args` and the parsing logic when BC allows.
        # Legacy positional form: state_dict(destination, prefix, keep_vars).
        # A positional value is used only when the matching keyword argument
        # still holds its default.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == '':
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
            warnings.warn(
                "Positional args are being deprecated, use kwargs instead. Refer to "
                "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
                " for details.")

        # Top-level call: create the result mapping plus a ``_metadata`` side
        # table that collects per-module version info.
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()

        local_metadata = dict(version=self._version)
        # Metadata is keyed by the module path without the trailing '.'.
        # Children are visited below with ``prefix + name + '.'``, and the root
        # (prefix '') maps to key ''. A caller-supplied ``destination`` may
        # have no ``_metadata``, in which case this step is skipped.
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata

        for hook in self._state_dict_pre_hooks.values():
            hook(self, prefix, keep_vars)

        # NOTE(review): presumably writes only this module's own parameters and
        # buffers (children are handled by the loop below) — confirm in
        # _save_to_state_dict.
        self._save_to_state_dict(destination, prefix, keep_vars)

        # Recurse into non-None children; they all write into the same mapping.
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

        # Post-hooks may swap in a different mapping by returning it.
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result

        return destination

to

def to(
    self,
    *args,
    **kwargs
)

Move and/or cast the parameters and buffers.

This can be called as

.. function:: to(device=None, dtype=None, non_blocking=False) :noindex:

.. function:: to(dtype, non_blocking=False) :noindex:

.. function:: to(tensor, non_blocking=False) :noindex:

.. function:: to(memory_format=torch.channels_last) :noindex:

Its signature is similar to :meth:torch.Tensor.to, but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:dtype (if given). The integral parameters and buffers will be moved to :attr:device, if that is given, but with dtypes unchanged. When :attr:non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

.. note:: This method modifies the module in-place.

Parameters:

Name Type Description Default
device torch.device the desired device of the parameters
and buffers in this module
None
dtype torch.dtype the desired floating point or complex dtype of
the parameters and buffers in this module
None
tensor torch.Tensor Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
None
memory_format torch.memory_format the desired memory
format for 4D parameters and buffers in this module (keyword
only argument)
None

Returns:

Type Description
Module self
View Source
    def to(self, *args, **kwargs):
        r"""Move and/or cast the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point or complex :attr:`dtype`\ s. In addition, this method will
        only cast the floating point or complex parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        A ``memory_format``, if given, is applied only to 4D and 5D tensors.

        See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
                the parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module
            memory_format (:class:`torch.memory_format`): the desired memory
                format for 4D parameters and buffers in this module (keyword
                only argument)

        Returns:
            Module: self

        Raises:
            TypeError: if ``dtype`` is neither floating point nor complex.
            NotImplementedError: if a tensor cannot be copied, e.g. a meta
                tensor. In that case the message suggests
                :meth:`torch.nn.Module.to_empty`.

        Examples::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.3741+0.j,  0.2382+0.j],
                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
            tensor([[0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

        """
        # Parse the overloaded call forms listed in the docstring into four
        # values: target device, target dtype, non_blocking flag, memory format.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            # Integral target dtypes are rejected outright (see docstring).
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError('nn.Module.to only accepts floating point or complex '
                                f'dtypes, but got desired dtype={dtype}')
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected.")

        # Per-tensor conversion passed to ``self._apply``.
        def convert(t):
            try:
                # The memory format only applies to 4D/5D tensors. For others,
                # the plain ``to`` call below is used.
                if convert_to_format is not None and t.dim() in (4, 5):
                    return t.to(
                        device,
                        # Integral tensors get ``None`` here, so they keep their dtype.
                        dtype if t.is_floating_point() or t.is_complex() else None,
                        non_blocking,
                        memory_format=convert_to_format,
                    )
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                )
            except NotImplementedError as e:
                # Meta tensors carry no data to copy. Re-raise with a hint to use
                # ``to_empty()``; ``from None`` drops the chained original
                # exception. Any other NotImplementedError propagates unchanged.
                if str(e) == "Cannot copy out of meta tensor; no data!":
                    raise NotImplementedError(
                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                        f"when moving module from meta to a different device."
                    ) from None
                else:
                    raise

        return self._apply(convert)

to_empty

def to_empty(
    self: ~T,
    *,
    device: Union[int, str, torch.device, NoneType],
    recurse: bool = True
) -> ~T

Move the parameters and buffers to the specified device without copying storage.

Parameters:

Name Type Description Default
device torch.device The desired device of the parameters
and buffers in this module.
None
recurse bool Whether parameters and buffers of submodules should
be recursively moved to the specified device.
True

Returns:

Type Description
Module self
View Source
    def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T:
        r"""Move the parameters and buffers to ``device`` without copying their data.

        Every tensor is replaced by a freshly allocated, uninitialized tensor
        (``torch.empty_like``) with the same shape and dtype on ``device``.

        Args:
            device (:class:`torch.device`): The desired device of the parameters
                and buffers in this module.
            recurse (bool): Whether parameters and buffers of submodules should
                be recursively moved to the specified device.

        Returns:
            Module: self
        """
        def _fresh(tensor):
            # ``empty_like`` allocates new storage; the old values are not copied.
            return torch.empty_like(tensor, device=device)

        return self._apply(_fresh, recurse=recurse)

train

def train(
    self: ~T,
    mode: bool = True
) -> ~T

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. `Dropout`, `BatchNorm`, etc.

Parameters:

Name Type Description Default
mode bool Whether to set training mode (True) or evaluation mode (False). True

Returns:

Type Description
Module self
View Source
    def train(self: T, mode: bool = True) -> T:
        r"""Put this module and all of its descendants into training or evaluation mode.

        Only modules whose behavior depends on the mode (e.g. :class:`Dropout`,
        :class:`BatchNorm`) are actually affected; see their documentation for
        details.

        Args:
            mode (bool): ``True`` selects training mode, ``False`` selects
                evaluation mode. Default: ``True``.

        Returns:
            Module: self

        Raises:
            ValueError: if ``mode`` is not a ``bool``.
        """
        if isinstance(mode, bool):
            self.training = mode
            # Each direct child applies the same mode to its own subtree.
            for child in self.children():
                child.train(mode)
            return self
        raise ValueError("training mode is expected to be boolean")

type

def type(
    self: ~T,
    dst_type: Union[torch.dtype, str]
) -> ~T

Casts all parameters and buffers to `dst_type`.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
dst_type torch.dtype or str The desired type. required

Returns:

Type Description
Module self
View Source
    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Convert every parameter and buffer of this module to ``dst_type``.

        .. note::
            The module is modified in-place.

        Args:
            dst_type (type or string): the target type, forwarded unchanged to
                each tensor's ``Tensor.type``.

        Returns:
            Module: self
        """

        def _cast(tensor):
            return tensor.type(dst_type)

        return self._apply(_cast)

xpu

def xpu(
    self: ~T,
    device: Union[int, torch.device, NoneType] = None
) -> ~T

Move all model parameters and buffers to the XPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the XPU while being optimized.

Note: This method modifies the module in-place.

Parameters:

Name Type Description Default
device int, torch.device or None If specified, all parameters will be copied to that device. None

Returns:

Type Description
Module self
View Source
    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Transfer every parameter and buffer of this module to the XPU.

        The associated parameters and buffers become different objects, so call
        this before constructing an optimizer if the module will live on the
        XPU while being optimized.

        .. note::
            The module is modified in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """

        def _to_xpu(tensor):
            # ``device`` here is the argument above, not the torch.device type.
            return tensor.xpu(device)

        return self._apply(_to_xpu)

zero_grad

def zero_grad(
    self,
    set_to_none: bool = True
) -> None

Reset gradients of all model parameters.

See the similar function under `torch.optim.Optimizer` for more context.

Parameters:

Name Type Description Default
set_to_none bool Instead of setting to zero, set the grads to None. See `torch.optim.Optimizer.zero_grad` for details. True
View Source
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Reset gradients of all model parameters.

        See the similar function under :class:`torch.optim.Optimizer` for more context.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                See :meth:`torch.optim.Optimizer.zero_grad` for details.
                Default: ``True``.
        """
        # Modules flagged with ``_is_replica`` only get a warning here; the
        # reset loop below still runs afterwards.
        if getattr(self, '_is_replica', False):
            warnings.warn(
                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
                "The parameters are copied (in a differentiable manner) from the original module. "
                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
                "If you need gradients in your forward method, consider using autograd.grad instead.")
        for p in self.parameters():
            # Parameters that never received a gradient are left untouched.
            if p.grad is not None:
                if set_to_none:
                    # Drop the gradient tensor entirely.
                    p.grad = None
                else:
                    # Keep the tensor but zero it in place. If the gradient is
                    # itself attached to an autograd graph (has a grad_fn), detach
                    # it in place first; otherwise just clear its requires_grad
                    # flag. Presumably this keeps the in-place zero_() below out of
                    # autograd tracking.
                    if p.grad.grad_fn is not None:
                        p.grad.detach_()
                    else:
                        p.grad.requires_grad_(False)
                    p.grad.zero_()