# Source code for shapenet.layer.homogeneous_shape_layer

# author: Justus Schock (justus.schock@rwth-aachen.de)

import torch
from .shape_layer import ShapeLayer
from .homogeneous_transform_layer import HomogeneousTransformationLayer


class HomogeneousShapeLayer(torch.nn.Module):
    """
    Module to perform a shape prediction (including a global homogeneous
    transformation)

    The layer chains a :class:`ShapeLayer` and a
    :class:`HomogeneousTransformationLayer`. The incoming parameter tensor is
    split along dimension 1: the first ``ShapeLayer.num_params`` entries are
    fed to the shape layer, the remaining entries to the transformation
    layer.

    """

    def __init__(self, shapes, n_dims, use_cpp=False):
        """

        Parameters
        ----------
        shapes : np.ndarray
            shapes to construct a :class:`ShapeLayer`
        n_dims : int
            number of shape dimensions
        use_cpp : bool
            whether or not to use (experimental) C++ Implementation

        See Also
        --------
        :class:`ShapeLayer`
        :class:`HomogeneousTransformationLayer`

        """
        super().__init__()

        self._shape_layer = ShapeLayer(shapes, use_cpp=use_cpp)
        self._homogen_trafo = HomogeneousTransformationLayer(
            n_dims, use_cpp=use_cpp
        )

        # Both sub-layers must exist before ``self.num_params`` is read,
        # because the property sums their parameter counts.
        n_shape_params = self._shape_layer.num_params

        # Index tensors are registered as buffers so they follow the module
        # across ``.to(device)`` calls and are stored in the state dict.
        self.register_buffer(
            "_indices_shape_params", torch.arange(n_shape_params)
        )
        self.register_buffer(
            "_indices_homogen_params",
            torch.arange(n_shape_params, self.num_params),
        )

    def forward(self, params: torch.Tensor) -> torch.Tensor:
        """
        Performs the actual prediction

        Parameters
        ----------
        params : :class:`torch.Tensor`
            input parameters; dimension 1 holds the shape parameters
            followed by the homogeneous transformation parameters

        Returns
        -------
        :class:`torch.Tensor`
            predicted shape

        """
        # Split the parameters along dim 1 with the precomputed index buffers.
        shape_params = params.index_select(
            dim=1, index=self._indices_shape_params
        )
        transformation_params = params.index_select(
            dim=1, index=self._indices_homogen_params
        )

        shapes = self._shape_layer(shape_params)
        return self._homogen_trafo(shapes, transformation_params)

    @property
    def num_params(self) -> int:
        """
        Property to access this layer's number of parameters

        Returns
        -------
        int
            number of parameters (shape layer plus transformation layer)

        """
        return self._shape_layer.num_params + self._homogen_trafo.num_params