# Source code for shapenet.jit.abstract_network

# author: Justus Schock (justus.schock@rwth-aachen.de)

import torch
from abc import abstractmethod


class AbstractShapeNetwork(torch.jit.ScriptModule):
    """
    Abstract JIT Network

    Base class built on :class:`torch.jit.ScriptModule`; besides the
    module setup it provides a helper mapping a normalization name to the
    corresponding normalization class.

    """

    def __init__(self, **kwargs):
        """

        Parameters
        ----------
        **kwargs :
            arbitrary keyword arguments; accepted so that subclasses can
            pass their arguments through, but not used here

        """
        # ``optimize`` is deliberately not forwarded: where the argument
        # existed on ``ScriptModule.__init__`` its default was already
        # ``True``, and recent PyTorch releases no longer accept it (passing
        # it raises a ``TypeError``).
        super().__init__()

    @staticmethod
    def norm_type_to_class(norm_type):
        """
        Resolve the name of a normalization type to its class

        Parameters
        ----------
        norm_type : str
            name of the normalization; supported values are ``'instance'``
            and ``'batch'``

        Returns
        -------
        type or None
            :class:`torch.nn.InstanceNorm2d` for ``'instance'``,
            :class:`torch.nn.BatchNorm2d` for ``'batch'`` and ``None`` for
            any other value

        """
        norm_dict = {'instance': torch.nn.InstanceNorm2d,
                     'batch': torch.nn.BatchNorm2d}

        # unknown names intentionally map to ``None`` instead of raising
        return norm_dict.get(norm_type)
class AbstractFeatureExtractor(torch.jit.ScriptModule):
    """
    Abstract base class for feature extractors

    Every concrete feature extractor should derive from this class and
    provide its own :meth:`_build_model`.

    """

    def __init__(self, in_channels, out_params, norm_class, p_dropout=0):
        """

        Parameters
        ----------
        in_channels : int
            number of input channels
        out_params : int
            number of outputs
        norm_class : Any
            class implementing a normalization
        p_dropout : float
            dropout probability

        """
        super().__init__()

        # arguments are forwarded positionally and unchanged to the builder
        network = self._build_model(in_channels, out_params,
                                    norm_class, p_dropout)
        self.model = network

    @torch.jit.script_method
    def forward(self, input_batch):
        """
        Pass a batch through the wrapped model

        Parameters
        ----------
        input_batch : :class:`torch.Tensor`
            batch to feed through network

        Returns
        -------
        :class:`torch.Tensor`
            extracted features

        """
        return self.model(input_batch)

    @staticmethod
    @abstractmethod
    def _build_model(in_channels, out_features, norm_class, p_dropout):
        """
        Create the actual model structure; has to be overridden

        Parameters
        ----------
        in_channels : int
            number of input channels
        out_features : int
            number of outputs
        norm_class : Any
            class implementing a normalization
        p_dropout : float
            dropout probability

        Returns
        -------
        :class:`torch.jit.ScriptModule`
            ensembled model

        Raises
        ------
        NotImplementedError
            always, since this base implementation builds nothing

        """
        raise NotImplementedError