# Source code for shapenet.jit.shape_layer

# author: Justus Schock (justus.schock@rwth-aachen.de)

import numpy as np
import torch


class ShapeLayer(torch.jit.ScriptModule):
    """
    Scriptable shape layer that delegates all work to a wrapped
    implementation (currently always the Python one).
    """

    def __init__(self, shapes, use_cpp=False):
        """

        Parameters
        ----------
        shapes : np.ndarray
            the shape components needed by the actual shape layer
            implementation
        use_cpp : bool
            whether to use cpp implementation or not
            (Currently only the python version is supported)

        Raises
        ------
        AssertionError
            if ``use_cpp`` is truthy, since no cpp implementation is
            available

        """
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; do it first so nothing is built for
        # an unsupported configuration. The exception type stays the same
        # as the one the former ``assert`` produced.
        if use_cpp:
            raise AssertionError(
                "Currently only the Python Version is supported")

        super().__init__()

        self._layer = _ShapeLayerPy(shapes)

    @torch.jit.script_method
    def forward(self, shape_params: torch.Tensor):
        """
        Pass the shape parameters to the wrapped layer

        Parameters
        ----------
        shape_params : :class:`torch.Tensor`
            shape parameters, forwarded unchanged

        Returns
        -------
        :class:`torch.Tensor`
            whatever the wrapped layer returns for ``shape_params``

        """
        return self._layer(shape_params)

    @property
    def num_params(self):
        """
        Number of parameters as reported by the wrapped layer
        """
        return self._layer.num_params
class _ShapeLayerPy(torch.jit.ScriptModule):
    """
    Python Implementation of Shape Layer

    The first entry of the given shapes is stored as mean shape, all
    remaining entries are stored as components. The forward pass adds the
    parameter-weighted sum of the components to the mean shape.
    """

    def __init__(self, shapes):
        """

        Parameters
        ----------
        shapes : np.ndarray
            eigen shapes (obtained by PCA); ``shapes[0]`` is used as mean
            shape, ``shapes[1:]`` as components

        """
        super().__init__()

        mean_shape = torch.from_numpy(shapes[0]).float()
        # buffer layout: (1, *mean_shape.shape)
        self.register_buffer("_shape_mean", mean_shape.unsqueeze(0))

        components = [torch.from_numpy(_shape).float()
                      for _shape in shapes[1:]]

        if components:
            # buffer layout: (1, num_components, *component_shape)
            component_tensor = torch.stack(components).unsqueeze(0)
        else:
            # Only a mean shape was given: keep an empty component axis so
            # that the forward pass simply returns the mean shape instead
            # of failing during construction.
            component_tensor = mean_shape.new_zeros(
                (1, 0) + tuple(mean_shape.shape))

        self.register_buffer("_shape_components", component_tensor)

    @torch.jit.script_method
    def forward(self, shape_params: torch.Tensor):
        """
        Ensemble shape from parameters

        Parameters
        ----------
        shape_params : :class:`torch.Tensor`
            shape parameters; the first dimension is used as batch size
            and the tensor must be expandable to the batched component
            tensor (e.g. ``(batch, num_params, 1, 1)``)

        Returns
        -------
        :class:`torch.Tensor`
            ensembled shape (one per batch entry)

        """
        batch_size = shape_params.size(0)

        # ``expand`` only creates views and ``add`` below is out-of-place,
        # so the stored buffers are never modified.
        mean = getattr(self, "_shape_mean")
        mean = mean.expand(batch_size, mean.size(1), mean.size(2))

        components = getattr(self, "_shape_components")
        components = components.expand(batch_size,
                                       components.size(1),
                                       components.size(2),
                                       components.size(3))

        weighted_components = components.mul(
            shape_params.expand_as(components))

        # sum over the component axis, then offset by the mean shape
        return mean.add(weighted_components.sum(dim=1))

    @property
    def num_params(self):
        """
        Number of stored components (size of the component axis)
        """
        return getattr(self, "_shape_components").size(1)