Source code for shapenet.scripts.train_single_shapenet

# author: Justus Schock (justus.schock@rwth-aachen.de)


[docs]def train_shapenet(): """ Trains a single shapenet with config file from comandline arguments See Also -------- :class:`delira.training.PyTorchNetworkTrainer` """ import logging import numpy as np import torch from shapedata.single_shape import SingleShapeDataset from delira.training import PyTorchNetworkTrainer from ..utils import Config from ..layer import HomogeneousShapeLayer from ..networks import SingleShapeNetwork from delira.logging import TrixiHandler from trixi.logger import PytorchVisdomLogger from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch from delira.data_loading import BaseDataManager, RandomSampler, \ SequentialSampler import os import argparse parser = argparse.ArgumentParser() parser.add_argument("-c", "--config", type=str, help="Path to configuration file") parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() config = Config() config_dict = config(os.path.abspath(args.config)) shapes = np.load(os.path.abspath(config_dict["layer"].pop("pca_path")) )["shapes"][:config_dict["layer"].pop("num_shape_params") + 1] # layer_cls = HomogeneousShapeLayer net = SingleShapeNetwork( HomogeneousShapeLayer, {"shapes": shapes, **config_dict["layer"]}, img_size=config_dict["data"]["img_size"], **config_dict["network"]) num_params = 0 for param in net.parameters(): num_params += param.numel() if args.verbose: print("Number of Parameters: %d" % num_params) criterions = {"L1": torch.nn.L1Loss()} metrics = {"MSE": torch.nn.MSELoss()} mixed_prec = config_dict["training"].pop("mixed_prec", False) config_dict["training"]["save_path"] = os.path.abspath( config_dict["training"]["save_path"]) trainer = PyTorchNetworkTrainer( net, criterions=criterions, metrics=metrics, lr_scheduler_cls=ReduceLROnPlateauCallbackPyTorch, lr_scheduler_params=config_dict["scheduler"], optimizer_cls=torch.optim.Adam, optimizer_params=config_dict["optimizer"], mixed_precision=mixed_prec, **config_dict["training"]) if args.verbose: print(trainer.input_device) print("Load Data") dset_train = SingleShapeDataset( os.path.abspath(config_dict["data"]["train_path"]), config_dict["data"]["img_size"], config_dict["data"]["crop"], config_dict["data"]["landmark_extension_train"], cached=config_dict["data"]["cached"], rotate=config_dict["data"]["rotate_train"], random_offset=config_dict["data"]["offset_train"] ) if config_dict["data"]["test_path"]: dset_val = SingleShapeDataset( os.path.abspath(config_dict["data"]["test_path"]), config_dict["data"]["img_size"], config_dict["data"]["crop"], config_dict["data"]["landmark_extension_test"], cached=config_dict["data"]["cached"], rotate=config_dict["data"]["rotate_test"], random_offset=config_dict["data"]["offset_test"] ) else: dset_val = None mgr_train = BaseDataManager( dset_train, batch_size=config_dict["data"]["batch_size"], n_process_augmentation=config_dict["data"]["num_workers"], transforms=None, sampler_cls=RandomSampler ) mgr_val = BaseDataManager( dset_val, batch_size=config_dict["data"]["batch_size"], n_process_augmentation=config_dict["data"]["num_workers"], transforms=None, sampler_cls=SequentialSampler ) if args.verbose: print("Data loaded") if config_dict["logging"].pop("enable", False): logger_cls = PytorchVisdomLogger logging.basicConfig(level=logging.INFO, handlers=[ TrixiHandler( logger_cls, **config_dict["logging"]) ]) else: logging.basicConfig(level=logging.INFO, handlers=[logging.NullHandler()]) logger = logging.getLogger("Test Logger") logger.info("Start Training") if args.verbose: print("Start Training") trainer.train(config_dict["training"]["num_epochs"], mgr_train, mgr_val, config_dict["training"]["val_score_key"], val_score_mode='lowest')
if __name__ == '__main__': train_shapenet()