diff --git a/Functional2Layer.py b/Functional2Layer.py new file mode 100644 index 0000000..4ca3d30 --- /dev/null +++ b/Functional2Layer.py @@ -0,0 +1,41 @@ +import torch +from typing import Callable, Any + + +class Functional2Layer(torch.nn.Module): + def __init__( + self, func: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> None: + super().__init__() + self.func = func + self.args = args + self.kwargs = kwargs + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.func(input, *self.args, **self.kwargs) + + def extra_repr(self) -> str: + func_name = ( + self.func.__name__ if hasattr(self.func, "__name__") else str(self.func) + ) + args_repr = ", ".join(map(repr, self.args)) + kwargs_repr = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) + return f"func={func_name}, args=({args_repr}), kwargs={{{kwargs_repr}}}" + + +if __name__ == "__main__": + + print("Permute Example") + test_layer_permute = Functional2Layer(func=torch.permute, dims=(0, 2, 3, 1)) + input = torch.zeros((10, 11, 12, 13)) + output = test_layer_permute(input) + print(input.shape) + print(output.shape) + print(test_layer_permute) + + print() + print("Clamp Example") + test_layer_clamp = Functional2Layer(func=torch.clamp, min=5, max=100) + output = test_layer_permute(input) + print(output[0, 0, 0, 0]) + print(test_layer_clamp) diff --git a/SequentialSplit.py b/SequentialSplit.py new file mode 100644 index 0000000..c20dcbe --- /dev/null +++ b/SequentialSplit.py @@ -0,0 +1,169 @@ +import torch +from typing import Callable + + +class SequentialSplit(torch.nn.Module): + """ + A PyTorch module that splits the processing path of a input tensor + and processes it through multiple torch.nn.Sequential segments, + then combines the outputs using a specified methods. + + This module allows for creating split paths within a `torch.nn.Sequential` + model, making it possible to implement architectures with skip connections + or parallel paths without abandoning the sequential model structure. + + Attributes: + segments (torch.nn.Sequential[torch.nn.Sequential]): A list of sequential modules to + process the input tensor. + combine_func (Callable | None): A function to combine the outputs + from the segments. + dim (int | None): The dimension along which to concatenate + the outputs if `combine_func` is `torch.cat`. + + Args: + segments (torch.nn.Sequential[torch.nn.Sequential]): A torch.nn.Sequential + with a list of sequential modules to process the input tensor. + combine (str, optional): The method to combine the outputs. + "cat" for concatenation (default), "sum" for a summation, + or "func" to use a custom combine function. + dim (int | None, optional): The dimension along which to + concatenate the outputs if `combine` is "cat". + Defaults to 1. + combine_func (Callable | None, optional): A custom function + to combine the outputs if `combine` is "func". + Defaults to None. + + Example: + A simple example for the `SequentialSplit` module with two sub-torch.nn.Sequential: + + ----- segment_a ----- + main_Sequential ----| |---- main_Sequential + ----- segment_b ----- + + segments = [segment_a, segment_b] + y_split = SequentialSplit(segments) + result = y_split(input_tensor) + + Methods: + forward(input: torch.Tensor) -> torch.Tensor: + Processes the input tensor through the segments and + combines the results. + """ + + segments: torch.nn.Sequential + combine_func: Callable + dim: int | None + + def __init__( + self, + segments: torch.nn.Sequential, + combine: str = "cat", # "cat", "sum", "func", + dim: int | None = 1, + combine_func: Callable | None = None, + ): + super().__init__() + self.segments = segments + self.dim = dim + + self.combine = combine + + if combine.upper() == "CAT": + self.combine_func = torch.cat + elif combine.upper() == "SUM": + self.combine_func = self.sum + self.dim = None + else: + assert combine_func is not None + self.combine_func = combine_func + + def sum(self, input: list[torch.Tensor]) -> torch.Tensor | None: + + if len(input) == 0: + return None + + if len(input) == 1: + return input[0] + + output: torch.Tensor = input[0] + + for i in range(1, len(input)): + output = output + input[i] + + return output + + def forward(self, input: torch.Tensor) -> torch.Tensor: + results: list[torch.Tensor] = [] + for segment in self.segments: + results.append(segment(input)) + + if self.dim is None: + return self.combine_func(results) + else: + return self.combine_func(results, dim=self.dim) + + def extra_repr(self) -> str: + return self.combine + + +if __name__ == "__main__": + + print("Example CAT") + strain_a = torch.nn.Sequential(torch.nn.Identity()) + strain_b = torch.nn.Sequential(torch.nn.Identity()) + strain_c = torch.nn.Sequential(torch.nn.Identity()) + test_cat = SequentialSplit( + torch.nn.Sequential(strain_a, strain_b, strain_c), combine="cat", dim=2 + ) + print(test_cat) + input = torch.ones((10, 11, 12, 13)) + output = test_cat(input) + print(input.shape) + print(output.shape) + print(input[0, 0, 0, 0]) + print(output[0, 0, 0, 0]) + print() + + print("Example SUM") + strain_a = torch.nn.Sequential(torch.nn.Identity()) + strain_b = torch.nn.Sequential(torch.nn.Identity()) + strain_c = torch.nn.Sequential(torch.nn.Identity()) + test_sum = SequentialSplit( + torch.nn.Sequential(strain_a, strain_b, strain_c), combine="sum", dim=2 + ) + print(test_sum) + input = torch.ones((10, 11, 12, 13)) + output = test_sum(input) + print(input.shape) + print(output.shape) + print(input[0, 0, 0, 0]) + print(output[0, 0, 0, 0]) + print() + + print("Example Labeling") + strain_a = torch.nn.Sequential() + strain_a.add_module("Label for first strain", torch.nn.Identity()) + strain_b = torch.nn.Sequential() + strain_b.add_module("Label for second strain", torch.nn.Identity()) + strain_c = torch.nn.Sequential() + strain_c.add_module("Label for third strain", torch.nn.Identity()) + test_label = SequentialSplit(torch.nn.Sequential(strain_a, strain_b, strain_c)) + print(test_label) + print() + + print("Example Get Parameter") + input = torch.ones((10, 11, 12, 13)) + strain_a = torch.nn.Sequential() + strain_a.add_module("Identity", torch.nn.Identity()) + strain_b = torch.nn.Sequential() + strain_b.add_module( + "Conv2d", + torch.nn.Conv2d( + in_channels=input.shape[1], + out_channels=input.shape[1], + kernel_size=(1, 1), + ), + ) + test_parameter = SequentialSplit(torch.nn.Sequential(strain_a, strain_b)) + print(test_parameter) + for name, param in test_parameter.named_parameters(): + print(f"Parameter name: {name}, Shape: {param.shape}") diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..664080c --- /dev/null +++ b/__init__.py @@ -0,0 +1,41 @@ +from . import parametrizations, rnn, stateless +from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_ +from .convert_parameters import parameters_to_vector, vector_to_parameters +from .fusion import ( + fuse_conv_bn_eval, + fuse_conv_bn_weights, + fuse_linear_bn_eval, + fuse_linear_bn_weights, +) +from .init import skip_init +from .memory_format import ( + convert_conv2d_weight_memory_format, + convert_conv3d_weight_memory_format, +) +from .spectral_norm import remove_spectral_norm, spectral_norm +from .weight_norm import remove_weight_norm, weight_norm + +from .Functional2Layer import Functional2Layer + +__all__ = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grad_value_", + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + "parameters_to_vector", + "parametrizations", + "remove_spectral_norm", + "remove_weight_norm", + "rnn", + "skip_init", + "spectral_norm", + "stateless", + "vector_to_parameters", + "weight_norm", + "Functional2Layer", +] diff --git a/container.py b/container.py new file mode 100644 index 0000000..68bb63b --- /dev/null +++ b/container.py @@ -0,0 +1,968 @@ +# mypy: allow-untyped-defs +import operator +from collections import abc as container_abcs, OrderedDict +from itertools import chain, islice +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + overload, + Tuple, + TypeVar, + Union, +) +from typing_extensions import deprecated, Self + +import torch +from torch._jit_internal import _copy_to_script_wrapper +from torch.nn.parameter import Parameter + +from .module import Module +from .SequentialSplit import SequentialSplit + +__all__ = [ + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", + "SequentialSplit", +] + +T = TypeVar("T", bound=Module) + + +# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList +def _addindent(s_, numSpaces): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +@deprecated( + "`nn.Container` is deprecated. " + "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.", + category=FutureWarning, +) +class Container(Module): + def __init__(self, **kwargs: Any) -> None: + super().__init__() + for key, value in kwargs.items(): + self.add_module(key, value) + + +class Sequential(Module): + r"""A sequential container. + + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + + Example:: + + # Using Sequential to create a small model. When `model` is run, + # input will first be passed to `Conv2d(1,20,5)`. The output of + # `Conv2d(1,20,5)` will be used as the input to the first + # `ReLU`; the output of the first `ReLU` will become the input + # for `Conv2d(20,64,5)`. Finally, the output of + # `Conv2d(20,64,5)` will be used as input to the second `ReLU` + model = nn.Sequential( + nn.Conv2d(1,20,5), + nn.ReLU(), + nn.Conv2d(20,64,5), + nn.ReLU() + ) + + # Using Sequential with OrderedDict. This is functionally the + # same as the above code + model = nn.Sequential(OrderedDict([ + ('conv1', nn.Conv2d(1,20,5)), + ('relu1', nn.ReLU()), + ('conv2', nn.Conv2d(20,64,5)), + ('relu2', nn.ReLU()) + ])) + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: + ... + + @overload + def __init__(self, arg: "OrderedDict[str, Module]") -> None: + ... + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var] + """Get the idx-th item of the iterator.""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError(f"index {idx} is out of range") + idx %= size + return next(islice(iterator, idx, None)) + + @_copy_to_script_wrapper + def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> "Sequential": + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> Self: + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) + + def __mul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> "Sequential": + return self.__mul__(other) + + def __imul__(self, other: int) -> Self: + if not isinstance(other, int): + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + @_copy_to_script_wrapper + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> "Sequential": + r"""Append a given module to the end. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> "Sequential": + if not isinstance(module, Module): + raise AssertionError(f"module should be of type: {Module}") + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError(f"Index out of range: {index}") + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> "Sequential": + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @_copy_to_script_wrapper + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> Self: + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> "ModuleList": + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self): + """Return a custom repr for ModuleList that compresses repeated module representations.""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + @_copy_to_script_wrapper + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> "ModuleList": + r"""Append a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> Self: + r"""Append modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~torch.nn.Module` methods. + + :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~torch.nn.ModuleDict` (the argument to + :meth:`~torch.nn.ModuleDict.update`). + + Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping + types (e.g., Python's plain ``dict`` before Python version 3.6) does not + preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.choices = nn.ModuleDict({ + 'conv': nn.Conv2d(10, 10, 3), + 'pool': nn.MaxPool2d(3) + }) + self.activations = nn.ModuleDict([ + ['lrelu', nn.LeakyReLU()], + ['prelu', nn.PReLU()] + ]) + + def forward(self, x, choice, act): + x = self.choices[choice](x) + x = self.activations[act](x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + @_copy_to_script_wrapper + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + @_copy_to_script_wrapper + def __len__(self) -> int: + return len(self._modules) + + @_copy_to_script_wrapper + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + @_copy_to_script_wrapper + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + @_copy_to_script_wrapper + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + @_copy_to_script_wrapper + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + @_copy_to_script_wrapper + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + r"""Holds parameters in a list. + + :class:`~torch.nn.ParameterList` can be used like a regular Python + list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered, + and will be visible by all :class:`~torch.nn.Module` methods. + + Note that the constructor, assigning an element of the list, the + :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend` + method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`. + + Args: + parameters (iterable, optional): an iterable of elements to add to the list. + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)]) + + def forward(self, x): + # ParameterList can act as an iterable, or be indexed using ints + for i, p in enumerate(self.params): + x = self.params[i // 2].mm(x) + p.mm(x) + return x + """ + + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + super().__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f"index {idx} is out of range") + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: + ... + + @overload + def __getitem__(self: T, idx: slice) -> T: + ... + + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. + idx = self._get_abs_string_index(idx) + if isinstance(param, torch.Tensor) and not isinstance(param, Parameter): + param = Parameter(param) + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> Self: + return self.extend(parameters) + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> "ParameterList": + """Append a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> Self: + """Append values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance( + values, torch.Tensor + ): + raise TypeError( + "ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__ + ) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError("ParameterList should not be called.") + + +class ParameterDict(Module): + r"""Holds parameters in a dictionary. + + ParameterDict can be indexed like a regular Python dictionary, but Parameters it + contains are properly registered, and will be visible by all Module methods. + Other objects are treated as would be done by a regular Python dictionary + + :class:`~torch.nn.ParameterDict` is an **ordered** dictionary. + :meth:`~torch.nn.ParameterDict.update` with other unordered mapping + types (e.g., Python's plain ``dict``) does not preserve the order of the + merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict` + will preserve their ordering. + + Note that the constructor, assigning an element of the dictionary and the + :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into + :class:`~torch.nn.Parameter`. + + Args: + values (iterable, optional): a mapping (dictionary) of + (string : Any) or an iterable of key-value pairs + of type (string, Any) + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.params = nn.ParameterDict({ + 'left': nn.Parameter(torch.randn(5, 10)), + 'right': nn.Parameter(torch.randn(5, 10)) + }) + + def forward(self, x, choice): + x = self.params[choice].mm(x) + return x + """ + + def __init__(self, parameters: Any = None) -> None: + super().__init__() + self._keys: Dict[str, None] = {} + if parameters is not None: + self.update(parameters) + + def _key_to_attr(self, key: str) -> str: + if not isinstance(key, str): + raise TypeError( + "Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys." + ) + else: + # Use the key as-is so that `.named_parameters()` returns the right thing + return key + + def __getitem__(self, key: str) -> Any: + attr = self._key_to_attr(key) + return getattr(self, attr) + + def __setitem__(self, key: str, value: Any) -> None: + # Note that all other function that add an entry to the dictionary part of + # the ParameterDict end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the dictionary part and thus won't + # call into this function. + self._keys[key] = None + attr = self._key_to_attr(key) + if isinstance(value, torch.Tensor) and not isinstance(value, Parameter): + value = Parameter(value) + setattr(self, attr, value) + + def __delitem__(self, key: str) -> None: + del self._keys[key] + attr = self._key_to_attr(key) + delattr(self, attr) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __reversed__(self) -> Iterator[str]: + return reversed(list(self._keys)) + + def copy(self) -> "ParameterDict": + """Return a copy of this :class:`~torch.nn.ParameterDict` instance.""" + # We have to use an OrderedDict because the ParameterDict constructor + # behaves differently on plain dict vs OrderedDict + return ParameterDict(OrderedDict((k, self[k]) for k in self._keys)) + + def __contains__(self, key: str) -> bool: + return key in self._keys + + def setdefault(self, key: str, default: Optional[Any] = None) -> Any: + """Set the default for a key in the Parameterdict. + + If key is in the ParameterDict, return its value. + If not, insert `key` with a parameter `default` and return `default`. + `default` defaults to `None`. + + Args: + key (str): key to set default for + default (Any): the parameter set to the key + """ + if key not in self: + self[key] = default + return self[key] + + def clear(self) -> None: + """Remove all items from the ParameterDict.""" + for k in self._keys.copy(): + del self[k] + + def pop(self, key: str) -> Any: + r"""Remove key from the ParameterDict and return its parameter. + + Args: + key (str): key to pop from the ParameterDict + """ + v = self[key] + del self[key] + return v + + def popitem(self) -> Tuple[str, Any]: + """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict.""" + k, _ = self._keys.popitem() + # We need the key in the _keys to be able to access/del + self._keys[k] = None + val = self[k] + del self[k] + return k, val + + def get(self, key: str, default: Optional[Any] = None) -> Any: + r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not. + + Args: + key (str): key to get from the ParameterDict + default (Parameter, optional): value to return if key not present + """ + return self[key] if key in self else default + + def fromkeys( + self, keys: Iterable[str], default: Optional[Any] = None + ) -> "ParameterDict": + r"""Return a new ParameterDict with the keys provided. + + Args: + keys (iterable, string): keys to make the new ParameterDict from + default (Parameter, optional): value to set for all keys + """ + return ParameterDict((k, default) for k in keys) + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ParameterDict keys.""" + return self._keys.keys() + + def items(self) -> Iterable[Tuple[str, Any]]: + r"""Return an iterable of the ParameterDict key/value pairs.""" + return ((k, self[k]) for k in self._keys) + + def values(self) -> Iterable[Any]: + r"""Return an iterable of the ParameterDict values.""" + return (self[k] for k in self._keys) + + def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None: + r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys. + + .. note:: + If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + parameters (iterable): a mapping (dictionary) from string to + :class:`~torch.nn.Parameter`, or an iterable of + key-value pairs of type (string, :class:`~torch.nn.Parameter`) + """ + if not isinstance(parameters, container_abcs.Iterable): + raise TypeError( + "ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(parameters).__name__ + ) + + if isinstance(parameters, (OrderedDict, ParameterDict)): + for key, parameter in parameters.items(): + self[key] = parameter + elif isinstance(parameters, container_abcs.Mapping): + for key, parameter in sorted(parameters.items()): + self[key] = parameter + else: + for j, p in enumerate(parameters): + if not isinstance(p, container_abcs.Iterable): + raise TypeError( + "ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + if not len(p) == 2: + raise ValueError( + "ParameterDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment + self[p[0]] = p[1] # type: ignore[assignment] + + def extra_repr(self) -> str: + child_lines = [] + for k, p in self.items(): + if isinstance(p, torch.Tensor): + size_str = "x".join(str(size) for size in p.size()) + if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: + device_str = f" ({p.device})" + else: + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + torch.typename(p), + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) + else: + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("ParameterDict should not be called.") + + def __or__(self, other: "ParameterDict") -> "ParameterDict": + copy = self.copy() + copy.update(other) + return copy + + def __ror__(self, other: "ParameterDict") -> "ParameterDict": + copy = other.copy() + copy.update(self) + return copy + + def __ior__(self, other: "ParameterDict") -> Self: + self.update(other) + return self