# Source code for shapenet.scripts.predict_from_net


def _load_network(weight_file, config_dict):
    """Load the network from ``weight_file``; return it on the CPU in eval mode.

    A TorchScript archive is tried first. If ``torch.jit.load`` raises a
    ``RuntimeError``, the file is treated as a ``torch.save`` checkpoint and
    a ``SingleShapeNetwork`` built from ``config_dict`` receives its weights.

    Note: on the checkpoint path this pops ``"mixed_prec"`` from
    ``config_dict["training"]`` and ``"pca_path"`` / ``"num_shape_params"``
    from ``config_dict["layer"]``; the remaining ``"layer"`` entries are
    forwarded to the network.
    """
    import os
    import warnings

    import numpy as np
    import torch

    from ..layer import HomogeneousShapeLayer
    from ..networks import SingleShapeNetwork

    weight_path = os.path.abspath(weight_file)

    try:
        # map_location keeps archives saved on a GPU loadable on CPU-only
        # hosts; the model is moved to the CPU right afterwards anyway.
        net = torch.jit.load(weight_path, map_location="cpu")
        net.eval()
        net.cpu()
        return net
    except RuntimeError:
        # Not a TorchScript archive: fall back to a plain checkpoint below.
        pass

    if config_dict["training"].pop("mixed_prec", False):
        # apex is optional here, so mixed precision is best-effort. Catch
        # Exception (not a bare except) so Ctrl-C / SystemExit still work.
        try:
            from apex import amp
            amp.init()
        except Exception as err:
            warnings.warn("Could not initialise apex amp ({}); continuing "
                          "without mixed precision.".format(err))

    layer_cfg = config_dict["layer"]
    pca_path = os.path.abspath(layer_cfg.pop("pca_path"))
    num_shape_params = layer_cfg.pop("num_shape_params")
    # Keeps num_shape_params + 1 entries (presumably the mean shape plus the
    # shape components -- TODO confirm against the PCA export).
    shapes = np.load(pca_path)["shapes"][:num_shape_params + 1]

    net = SingleShapeNetwork(
        HomogeneousShapeLayer, {"shapes": shapes, **layer_cfg},
        img_size=config_dict["data"]["img_size"],
        **config_dict["network"])

    state = torch.load(weight_path, map_location="cpu")

    # Checkpoints may store the model weights at different nesting levels.
    try:
        model_state = state["state_dict"]["model"]
    except KeyError:
        try:
            model_state = state["model"]
        except KeyError:
            model_state = state
    net.load_state_dict(model_state)

    return net.to("cpu").eval()


def _process_sample(sample, img_size, net, device, crop=0.1):
    """Predict landmarks for one sample, in original image coordinates.

    The image is cut to a square around the sample's own landmarks
    (``sample.lmk``). The side of the square is the larger landmark extent
    enlarged by the factor ``1 + crop``. The crop is converted to grayscale,
    resized to ``img_size`` x ``img_size`` and passed through ``net``. The
    network output is then scaled and shifted from crop coordinates back to
    the original image.

    Returns:
        numpy array of predictions. Presumably (num_landmarks, 2) in (y, x)
        order, since callers plot ``pred[:, 1]`` as x -- TODO confirm.
    """
    import numpy as np
    import torch

    min_y, min_x, max_y, max_x = sample.get_landmark_bounds(sample.lmk)
    range_x = max_x - min_x
    range_y = max_y - min_y

    # Side length of the square crop, padded relative to the larger extent.
    side = max(range_x, range_y) * (1 + crop)

    center_x = min_x + range_x / 2
    center_y = min_y + range_y / 2
    top = center_y - side / 2
    left = center_x - side / 2

    cropped = sample.crop(top, left,
                          center_y + side / 2, center_x + side / 2)

    # Move the last image axis to the front and add a batch dimension
    # (presumably HWC -> 1xCxHxW).
    img = cropped.to_grayscale().resize((img_size, img_size)).img
    img_tensor = torch.from_numpy(
        img.transpose(2, 0, 1)).to(torch.float).unsqueeze(0).to(device)

    pred = net(img_tensor).cpu().numpy()[0]

    # Undo the resize (scale) and the crop offset (shift).
    scale = side / img_size
    pred = pred * np.array([scale, scale])
    pred = pred + np.asarray([top, left])

    return pred


def predict():
    """
    Predict landmarks for all images in a directory and write them to disk.

    Command-line entry point. Arguments are parsed from ``sys.argv``:

    * ``-d/--in_path`` (required): directory with ``.png`` / ``.jpg`` images
    * ``-w/--weight_file`` (required): TorchScript archive or ``torch.save``
      checkpoint with the network weights
    * ``-c/--config_file`` (required): configuration parsed by ``Config``
    * ``-s/--out_path``: output directory (default ``./outputs``)
    * ``-v/--visualize``: also save a plot of the sample's landmarks together
      with the prediction for every image

    For every image, the loaded sample is written with
    ``sample.save(..., "PTS")`` and the prediction with ``pts_exporter`` as
    ``<name>_pred.pts``. With ``--visualize``, these files go to
    ``<out_path>/pred`` and the plots to ``<out_path>/visualization``.
    Otherwise they go directly to ``<out_path>``. Inference runs on CUDA
    when it is available.
    """
    import argparse
    import os

    import torch
    from matplotlib import pyplot as plt
    from tqdm import tqdm

    from ..utils import Config
    from shapedata.single_shape import SingleShapeDataProcessing, \
        SingleShapeSingleImage2D
    from shapedata.io import pts_exporter

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--visualize", action="store_true",
                        help="If Flag is specified, results will be plotted")
    # These paths are needed unconditionally, so let argparse enforce them
    # instead of failing later with os.path.abspath(None).
    parser.add_argument("-d", "--in_path", type=str, required=True,
                        help="Input Data Dir")
    parser.add_argument("-s", "--out_path", default="./outputs", type=str,
                        help="Output Data Dir")
    parser.add_argument("-w", "--weight_file", type=str, required=True,
                        help="Model Weights")
    parser.add_argument("-c", "--config_file", type=str, required=True,
                        help="Configuration")
    args = parser.parse_args()

    config = Config()
    config_dict = config(os.path.abspath(args.config_file))

    net = _load_network(args.weight_file, config_dict)

    data = SingleShapeDataProcessing._get_files(
        os.path.abspath(args.in_path), extensions=[".png", ".jpg"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = config_dict["data"]["img_size"]
    out_path = os.path.abspath(args.out_path)

    with torch.no_grad():
        net = net.to(device)

        if args.visualize:
            pred_path = os.path.join(out_path, "pred")
            vis_path = os.path.join(out_path, "visualization")
            os.makedirs(vis_path, exist_ok=True)
        else:
            pred_path = out_path

        os.makedirs(pred_path, exist_ok=True)

        for img_file in tqdm(data):
            sample = SingleShapeSingleImage2D.from_files(img_file)

            pred = _process_sample(sample, img_size=img_size, net=net,
                                   device=device)

            # File name without directory and without the last extension.
            fname = os.path.basename(sample.img_file).rsplit(".", 1)[0]

            if args.visualize:
                view_kwargs = {"cmap": "gray"} if sample.is_gray else {}
                sample.view(True, **view_kwargs)
                ax = plt.gca()
                ax.scatter(pred[:, 1], pred[:, 0], s=5, c="C1")
                ax.legend(["GT", "Pred"])
                plt.gcf().savefig(os.path.join(vis_path, fname + ".png"))
                plt.close()

            sample.save(pred_path, fname, "PTS")
            pts_exporter(pred, os.path.join(pred_path, fname + "_pred.pts"))
if __name__ == "__main__":
    # Run the prediction CLI when this module is executed as a script.
    predict()