Source code for workflow.torch.module_compose

import torch
import torch.nn as nn


class ModuleCompose(nn.Module):
    '''
    ``ModuleCompose`` behaves like an extended ``torch.nn.Sequential``
    that also allows:

    - vanilla functions
    - expands tuples to the next function's arguments
    - specify a module and a wrapping function as seen in the example
      below

    Each positional argument is one *step*. A step is either a callable
    (module or plain function) or a ``(module_or_parameter, fn)`` pair,
    in which case ``fn`` is called with the first element of the pair
    followed by the current value(s).

    .. code-block:: python

        from torch import nn
        from workflow.torch import ModuleCompose

        ModuleCompose(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.Conv2d(32, 16, 3, padding=1),
            (
                nn.Conv2d(16, 16, 3, padding=1),
                lambda conv, x: conv(x) + x,
            ),
            lambda x: (x, x * 2),
            lambda x, x2: x + x2,
        )
    '''

    def __init__(self, *modules_and_functions):
        '''
        Store the steps and register every module / parameter found in
        them so that ``parameters()``, ``state_dict()``, ``to()`` etc.
        see them.

        :param modules_and_functions: steps applied in order by
            ``forward``; see the class docstring for accepted forms.
        '''
        super().__init__()
        # A plain tuple attribute is not scanned by ``nn.Module``, so the
        # modules and parameters inside it are registered explicitly via
        # the two containers below.
        self.modules_and_functions = modules_and_functions

        steps = self.modules_and_functions
        # First element of every ``(module_or_parameter, fn)`` pair.
        heads = [step[0] for step in steps if type(step) is tuple]

        # Order matters: it determines the ``state_dict`` keys. Bare
        # steps are registered first, then the heads of tuple steps.
        self.module_list = nn.ModuleList(
            [step for step in steps if isinstance(step, nn.Module)]
            + [head for head in heads if isinstance(head, nn.Module)]
        )
        self.parameter_list = nn.ParameterList(
            [step for step in steps if isinstance(step, nn.Parameter)]
            + [head for head in heads if isinstance(head, nn.Parameter)]
        )

    @staticmethod
    def _apply_step(step, x):
        '''
        Call a single step with the current value ``x``.

        ``x`` is expanded into positional arguments only when it is
        exactly a ``tuple``; tuple subclasses (for instance namedtuples)
        are passed on as one argument. Likewise only an exact ``tuple``
        step is treated as a ``(module, fn)`` pair.
        '''
        args = x if type(x) is tuple else (x,)
        if type(step) is tuple:
            module, fn = step
            return fn(module, *args)
        return step(*args)

    @staticmethod
    def _n_parameters_postfix(step):
        '''
        Return ``' n_parameters: N'`` when the step (or the first element
        of a ``(module, fn)`` step) is an ``nn.Module``, else ``''``.
        '''
        if type(step) is tuple:
            module, _ = step
        else:
            module = step
        if not isinstance(module, nn.Module):
            return ''
        n_parameters = sum(p.numel() for p in module.parameters())
        return f' n_parameters: {n_parameters}'

    def forward(self, *x):
        '''
        Run all steps in order and return the result of the last one.

        The positional inputs are passed to the first step as separate
        arguments. If no steps were given, the tuple of inputs is
        returned unchanged.
        '''
        for step in self.modules_and_functions:
            x = self._apply_step(step, x)
        return x

    @torch.no_grad()
    def debug(self, x):
        '''
        Run all steps like ``forward`` with gradients disabled, printing
        the shape (or type) of the value entering each step together
        with the step's parameter count when it is a module.

        Unlike ``forward`` this takes a single argument ``x``; if ``x``
        is exactly a ``tuple`` it is expanded into the first step's
        arguments.

        :return: the result of the last step.
        '''
        for index, step in enumerate(self.modules_and_functions):
            print_intermediate(index, x, self._n_parameters_postfix(step))
            x = self._apply_step(step, x)
        return x
def print_intermediate(index, x, postfix):
    '''
    Print a one-line summary of the value ``x`` entering step ``index``.

    For a value with a ``shape`` attribute the shape is printed,
    otherwise its type. For an exact ``tuple`` the shapes of all elements
    are printed when every element has a ``shape``; in every other case
    (including an empty tuple) the element types are printed instead.

    :param index: position of the step about to be run.
    :param x: value that will be passed to the step.
    :param postfix: text appended verbatim to the printed line.
    '''
    if type(x) is tuple:
        # Require every element to expose ``shape`` (and at least one
        # element to exist) so mixed or empty tuples fall back to types
        # instead of raising.
        if x and all(hasattr(item, 'shape') for item in x):
            representation = f'shape: {[item.shape for item in x]}'
        else:
            representation = f'type: {[type(item) for item in x]}'
    elif hasattr(x, 'shape'):
        representation = f'shape: {x.shape}'
    else:
        representation = f'type: {type(x)}'

    print(f'index: {index}, {representation}' + postfix)


def test_module_compose():
    '''
    Smoke test: run ``ModuleCompose.debug`` on a small pipeline mixing
    plain functions and modules, and check the final output shape.
    '''
    import numpy as np

    class Example:
        def __init__(self, data):
            self.data = data

    model = ModuleCompose(
        lambda examples: torch.stack([
            torch.from_numpy(example.data).float()
            for example in examples
        ]),
        nn.Conv2d(3, 32, 5),
        lambda x: x.mean(dim=(-1, -2)),
        lambda x: x.view(x.size(0), -1),
        nn.Linear(32, 1),
    )

    batch = [Example(np.random.randn(3, 32, 32)) for _ in range(8)]
    output = model.debug(batch)
    # 8 examples in, one scalar prediction per example out.
    assert tuple(output.shape) == (8, 1)