Source code for imaginaire.generators.gancraft_base

# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from imaginaire.layers import Conv2dBlock, LinearBlock
from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear
import imaginaire.model_utils.gancraft.mc_utils as mc_utils
import imaginaire.model_utils.gancraft.voxlib as voxlib
from imaginaire.utils.distributed import master_only_print as print
class RenderMLP(nn.Module):
    r"""MLP with affine modulation."""

    def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
                 out_channels_s=1, out_channels_c=3, hidden_channels=256,
                 use_seg=True):
        super(RenderMLP, self).__init__()

        self.use_seg = use_seg
        if self.use_seg:
            self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)

        self.fc_viewdir = None
        if viewdir_dim > 0:
            self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)

        self.fc_1 = nn.Linear(in_channels, hidden_channels)

        self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)
        self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)
        self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)

        self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)

        if viewdir_dim > 0:
            self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
        else:
            self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
                                  bias=False, mod_bias=True, output_mode=True)
        self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)

        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)

        self.act = nn.LeakyReLU(negative_slope=0.2)
    def forward(self, x, raydir, z, m):
        r"""Forward network.

        Args:
            x (N x H x W x M x in_channels tensor): Projected features.
            raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
            z (N x style_dim tensor): Style codes.
            m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
        """
        b, h, w, n, _ = x.size()
        z = z[:, None, None, None, :]

        f = self.fc_1(x)
        if self.use_seg:
            f = f + self.fc_m_a(m)
        # Common MLP
        f = self.act(f)
        f = self.act(self.fc_2(f, z))
        f = self.act(self.fc_3(f, z))
        f = self.act(self.fc_4(f, z))

        # Sigma MLP
        sigma = self.fc_sigma(f)

        # Color MLP
        if self.fc_viewdir is not None:
            f = self.fc_5(f)
            f = f + self.fc_viewdir(raydir)
            f = self.act(self.mod_5(f, z))
        else:
            f = self.act(self.fc_5(f, z))
        f = self.act(self.fc_6(f, z))
        c = self.fc_out_c(f)
        return sigma, c
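
# Minimal usage sketch (not part of the original module; all shapes and
# dimensions below are illustrative assumptions, e.g. viewdir_dim=27 matches
# a 4-level positional encoding of a 3D ray direction plus the raw direction).
# It runs a dummy batch of per-sample features through RenderMLP to obtain
# per-sample opacities and color embeddings.
def _example_render_mlp():
    mlp = RenderMLP(in_channels=64, style_dim=128, viewdir_dim=27, mask_dim=680)
    x = torch.randn(1, 8, 8, 4, 64)       # N x H x W x M x in_channels
    raydir = torch.randn(1, 8, 8, 1, 27)  # per-ray view-direction embedding
    z = torch.randn(1, 128)               # one style code per image
    m = torch.zeros(1, 8, 8, 4, 680)      # one-hot segmentation maps
    m[..., 0] = 1.0                       # mark every sample as label 0
    sigma, c = mlp(x, raydir, z, m)       # sigma: ... x 1, c: ... x 3
    return sigma, c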
class StyleMLP(nn.Module):
    r"""MLP converting style code to intermediate style representation."""

    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True,
                 num_layers=5, normalize_input=True, output_act=True):
        super(StyleMLP, self).__init__()

        self.normalize_input = normalize_input
        self.output_act = output_act
        fc_layers = []
        fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
        for i in range(num_layers - 1):
            fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
        self.fc_layers = nn.ModuleList(fc_layers)

        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)
    def forward(self, z):
        r"""Forward network.

        Args:
            z (N x style_dim tensor): Style codes.
        """
        if self.normalize_input:
            z = F.normalize(z, p=2, dim=-1)
        for fc_layer in self.fc_layers:
            z = self.act(fc_layer(z))
        z = self.fc_out(z)
        if self.output_act:
            z = self.act(z)
        return z
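
# Minimal usage sketch (not part of the original module; the dimensions are
# assumptions). StyleMLP maps a raw style code to the intermediate style space
# that conditions the render networks, in the spirit of a StyleGAN-style
# mapping network: L2-normalize, several hidden layers, then project.
def _example_style_mlp():
    style_net = StyleMLP(style_dim=128, out_dim=256)
    z = torch.randn(4, 128)  # raw style codes
    w = style_net(z)         # intermediate style representation, 4 x 256
    return w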
class SKYMLP(nn.Module):
    r"""MLP converting ray directions to sky features."""

    def __init__(self, in_channels, style_dim, out_channels_c=3,
                 hidden_channels=256, leaky_relu=True):
        super(SKYMLP, self).__init__()
        self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)

        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, hidden_channels)
        self.fc3 = nn.Linear(hidden_channels, hidden_channels)
        self.fc4 = nn.Linear(hidden_channels, hidden_channels)
        self.fc5 = nn.Linear(hidden_channels, hidden_channels)

        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)
    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (... x in_channels tensor): Ray direction embeddings.
            z (... x style_dim tensor): Style codes.
        """
        z = self.fc_z_a(z)
        while z.dim() < x.dim():
            z = z.unsqueeze(1)

        y = self.act(self.fc1(x) + z)
        y = self.act(self.fc2(y))
        y = self.act(self.fc3(y))
        y = self.act(self.fc4(y))
        y = self.act(self.fc5(y))
        c = self.fc_out_c(y)

        return c
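
# Minimal usage sketch (not part of the original module; shapes are
# assumptions). The style code is projected once by fc_z_a and then broadcast
# over the per-pixel ray-direction embeddings, so a single z conditions the
# entire sky dome.
def _example_sky_mlp():
    sky_net = SKYMLP(in_channels=27, style_dim=256, out_channels_c=16)
    raydir_embed = torch.randn(1, 8, 8, 1, 27)  # N x H x W x 1 x in_channels
    z = torch.randn(1, 256)                     # one style code per image
    sky_feat = sky_net(raydir_embed, z)         # N x H x W x 1 x 16
    return sky_feat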
class RenderCNN(nn.Module):
    r"""CNN converting intermediate feature map to final image."""

    def __init__(self, in_channels, style_dim, hidden_channels=256,
                 leaky_relu=True):
        super(RenderCNN, self).__init__()
        self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)

        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)

        self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)

        self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)

        self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)
    def modulate(self, x, w, b):
        # Per-channel affine modulation: broadcast the per-channel scale w and
        # bias b over the spatial dimensions of x.
        w = w[..., None, None]
        b = b[..., None, None]
        return x * (w + 1) + b
    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (N x in_channels x H x W tensor): Intermediate feature map.
            z (N x style_dim tensor): Style codes.
        """
        z = self.fc_z_cond(z)
        adapt = torch.chunk(z, 2 * 2, dim=-1)

        y = self.act(self.conv1(x))

        y = y + self.conv2b(self.act(self.conv2a(y)))
        y = self.act(self.modulate(y, adapt[0], adapt[1]))

        y = y + self.conv3b(self.act(self.conv3a(y)))
        y = self.act(self.modulate(y, adapt[2], adapt[3]))

        y = y + self.conv4b(self.act(self.conv4a(y)))
        y = self.act(y)

        y = self.conv4(y)

        return y
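
# Minimal usage sketch (not part of the original module; sizes are
# assumptions). fc_z_cond splits the style code into two (scale, bias) pairs
# that modulate the feature map between the residual blocks.
def _example_render_cnn():
    cnn = RenderCNN(in_channels=16, style_dim=256)
    feat = torch.randn(1, 16, 64, 64)  # N x in_channels x H x W
    z = torch.randn(1, 256)
    rgb = cnn(feat, z)                 # N x 3 x H x W, pre-tanh
    return rgb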
class StyleEncoder(nn.Module):
    r"""Style Encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition file.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
        activation_norm_type = 'none'
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              # inplace_nonlinearity=True,
                              nonlinearity=nonlinearity)
        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
        if not self.no_vae:
            self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
    def forward(self, input_x):
        r"""SPADE Style Encoder forward.

        Args:
            input_x (N x 3 x H x W tensor): Input images.
        Returns:
            mu (N x C tensor): Mean vectors.
            logvar (N x C tensor): Log-variance vectors.
            z (N x C tensor): Style code vectors.
        """
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
        x = self.layer1(input_x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        if not self.no_vae:
            logvar = self.fc_var(x)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = eps.mul(std) + mu
        else:
            z = mu
            logvar = torch.zeros_like(mu)
        return mu, logvar, z
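
# Minimal sketch of the reparameterization step used above (not part of the
# original module). Sampling z = mu + std * eps with eps ~ N(0, I) keeps the
# sample differentiable w.r.t. mu and logvar, which is what lets the VAE-style
# encoder train end to end.
def _example_reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
    eps = torch.randn_like(std)    # noise, independent of the graph
    return mu + eps * std          # differentiable sample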
class Base3DGenerator(nn.Module):
    r"""Minecraft 3D generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super(Base3DGenerator, self).__init__()
        print('Base3DGenerator initialization.')

        # ---------------------- Main Network ------------------------
        # Exclude some of the features from positional encoding
        self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)

        # blk_feat passes through PE
        input_dim = (gen_cfg.blk_feat_dim - self.pe_no_pe_feat_dim) * (gen_cfg.pe_lvl_feat * 2) + \
            self.pe_no_pe_feat_dim
        if gen_cfg.pe_incl_orig_feat:
            input_dim += (gen_cfg.blk_feat_dim - self.pe_no_pe_feat_dim)
        print('[Base3DGenerator] Expected input dimensions: ', input_dim)
        self.input_dim = input_dim

        self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
        self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
        if self.pe_lvl_localcoords > 0:
            self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3

        # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
        input_dim_viewdir = 3 * (gen_cfg.pe_lvl_raydir * 2)
        if gen_cfg.pe_incl_orig_raydir:
            input_dim_viewdir += 3
        print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
        self.input_dim_viewdir = input_dim_viewdir

        self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
                          gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]

        # Style input dimension
        style_dims = gen_cfg.style_dims
        self.style_dims = style_dims
        interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
        self.interm_style_dims = interm_style_dims

        # ---------------------- Style MLP --------------------------
        self.style_net = globals()[gen_cfg.stylenet_model](
            style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)

        # Number of output channels for MLP (before blending)
        final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
        self.final_feat_dim = final_feat_dim

        # ----------------------- Sky Network -------------------------
        sky_input_dim_base = 3
        # Dedicated sky network input dimensions
        sky_input_dim = sky_input_dim_base * (gen_cfg.pe_lvl_raydir_sky * 2)
        if gen_cfg.pe_incl_orig_raydir_sky:
            sky_input_dim += sky_input_dim_base
        print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
        self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
        self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)

        # ----------------------- Style Encoder -------------------------
        style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
        setattr(style_enc_cfg, 'input_image_channels', 3)
        setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
        self.style_encoder = StyleEncoder(style_enc_cfg)

        # ---------------------- Ray Caster -------------------------
        self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
        self.num_samples = gen_cfg.num_samples
        self.sample_depth = gen_cfg.sample_depth
        self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
        self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)

        # ---------------------- Blender -------------------------
        self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
        self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
        self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
        self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
        self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
        keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
        self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
        if self.keep_sky_out:
            self.sky_replace_color = None
            if keep_sky_out_learnbg:
                sky_replace_color = torch.zeros([final_feat_dim])
                sky_replace_color.requires_grad = True
                self.sky_replace_color = torch.nn.Parameter(sky_replace_color)

        # ---------------------- render_cnn -------------------------
        self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)

        self.pad = gen_cfg.pad
    def get_param_groups(self, cfg_opt):
        print('[Generator] get_param_groups')

        if hasattr(cfg_opt, 'ignore_parameters'):
            print('[Generator::get_param_groups] [x]: ignored.')
            optimize_parameters = []
            for k, x in self.named_parameters():
                match = False
                for m in cfg_opt.ignore_parameters:
                    if re.match(m, k) is not None:
                        match = True
                        print('    [x]', k)
                        break
                if match is False:
                    print('    [v]', k)
                    optimize_parameters.append(x)
        else:
            optimize_parameters = self.parameters()

        param_groups = []
        param_groups.append({'params': optimize_parameters})

        if hasattr(cfg_opt, 'param_groups'):
            optimized_param_names = []
            all_param_names = [k for k, v in self.named_parameters()]
            param_groups = []
            for k, v in cfg_opt.param_groups.items():
                print('[Generator::get_param_groups] Adding param group from config:', k, v)
                params = getattr(self, k)
                named_parameters = [k]
                if issubclass(type(params), nn.Module):
                    named_parameters = [k + '.' + pname for pname, _ in params.named_parameters()]
                    params = params.parameters()
                param_groups.append({'params': params, **v})
                optimized_param_names.extend(named_parameters)

            print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n    ',
                  set(all_param_names) - set(optimized_param_names))

        return param_groups
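
    # Usage sketch (comment only; not part of the original module). cfg_opt is
    # read with hasattr/getattr, so a config only needs the fields it wants to
    # override, and the returned groups plug directly into a torch optimizer:
    #     param_groups = generator.get_param_groups(cfg.gen_opt)
    #     optimizer = torch.optim.Adam(param_groups, lr=1e-4)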
    def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
        r"""Forwarding the MLP.

        Args:
            blk_feats (K x C1 tensor): Sparse block features.
            worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points.
            raydirs_in (N x H x W x 1 x C2 tensor or None): Ray direction embeddings.
            z (N x C3 tensor): Intermediate style vectors.
            mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
        Returns:
            net_out_s (N x H x W x L x 1 tensor): Opacities.
            net_out_c (N x H x W x L x C5 tensor): Color embeddings.
        """
        proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
            blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)

        render_net_extra_kwargs = {}
        if self.pe_lvl_localcoords > 0:
            local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
            # Scale to [0, 2], as the positional encoding function doesn't have internal x2
            local_coords[torch.isnan(local_coords)] = 0.0
            local_coords = local_coords.contiguous()
            poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
            render_net_extra_kwargs['poscode'] = poscode

        if self.pe_params[0] == 0 and self.pe_params[1] is True:  # no PE shortcut, saves ~400MB
            feature_in = proj_feature
        else:
            if self.pe_no_pe_feat_dim > 0:
                feature_in = voxlib.positional_encoding(
                    proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(),
                    self.pe_params[0], -1, self.pe_params[1])
                feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
            else:
                feature_in = voxlib.positional_encoding(
                    proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])

        net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot,
                                               **render_net_extra_kwargs)

        if self.raw_noise_std > 0.:
            noise = torch.randn_like(net_out_s) * self.raw_noise_std
            net_out_s = net_out_s + noise

        return net_out_s, net_out_c

    def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
        r"""Sample points along rays, forward the per-point MLP and aggregate pixel features.

        Args:
            blk_feats (K x C1 tensor): Sparse block features.
            voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels
            depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
            raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
            cam_ori_t (N x 3 tensor): Camera origins.
            z (N x C3 tensor): Intermediate style vectors.
        """
        # Generate sky_mask; PE transform on ray direction.
        with torch.no_grad():
            raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
            if self.pe_params[2] == 0 and self.pe_params[3] is True:
                raydirs_in = raydirs_in
            elif self.pe_params[2] == 0 and self.pe_params[3] is False:  # Not using raydir at all
                raydirs_in = None
            else:
                raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])

            # sky_mask: when True, ray finally hits sky
            sky_mask = voxel_id[:, :, :, [-1], :] == 0
            # sky_only_mask: when True, ray hits nothing but sky
            sky_only_mask = voxel_id[:, :, :, [0], :] == 0

        with torch.no_grad():
            # Random sample points along the ray
            num_samples = self.num_samples + 1
            if self.sample_use_box_boundaries:
                num_samples = self.num_samples - self.num_blocks_early_stop

            # 10 samples per ray + 4 intersections - 2
            rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
                depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
                use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)

            worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]

            # Generate per-sample segmentation label
            voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
            mc_masks = torch.gather(voxel_id_reduced, -2, new_idx)  # B 256 256 N 1
            mc_masks = mc_masks.long()
            mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(2),
                                           mc_masks.size(3), self.num_reduced_labels],
                                          dtype=torch.float, device=voxel_id.device)
            # mc_masks_onehot: [B H W Nlayer 680]
            mc_masks_onehot.scatter_(-1, mc_masks, 1.0)

        net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)

        # Handle sky
        sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
        sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1,
                                                    self.pe_params_sky[1])
        skynet_out_c = self.sky_net(sky_raydirs_in, z)

        # Blending
        weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)

        # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
        weights = weights * torch.logical_not(sky_only_mask).float()
        total_weights_raw = torch.sum(weights, dim=-2, keepdim=True)  # 256 256 1 1
        total_weights = total_weights_raw

        is_gnd = worldcoord2[..., [0]] <= 1.0  # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
        is_gnd = is_gnd.any(dim=-2, keepdim=True)
        nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
        nosky_mask = nosky_mask.float()

        # Avoid sky leakage
        sky_weight = 1.0 - total_weights
        if self.keep_sky_out:
            # keep_sky_out_avgpool overrides sky_replace_color
            if self.sky_replace_color is None or self.keep_sky_out_avgpool:
                if self.keep_sky_out_avgpool:
                    if hasattr(self, 'sky_avg'):
                        sky_avg = self.sky_avg
                    else:
                        if self.sky_global_avgpool:
                            sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
                        else:
                            skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1)
                            sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15,
                                                   count_include_pad=False)
                            sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2)
                    # print(sky_avg.shape)
                    skynet_out_c = skynet_out_c * (1.0 - nosky_mask) + sky_avg * nosky_mask
                else:
                    sky_weight = sky_weight * (1.0 - nosky_mask)
            else:
                skynet_out_c = skynet_out_c * (1.0 - nosky_mask) + self.sky_replace_color * nosky_mask

        if self.clip_feat_map is True:  # intermediate feature before blending & CNN
            rgbs = torch.clamp(net_out_c, -1, 1) + 1
            rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + \
                sky_weight * rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
            net_out = net_out - 1
        elif self.clip_feat_map is False:
            rgbs = net_out_c
            rgbs_sky = skynet_out_c
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + \
                sky_weight * rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
        elif self.clip_feat_map == 'tanh':
            rgbs = torch.tanh(net_out_c)
            rgbs_sky = torch.tanh(skynet_out_c)
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + \
                sky_weight * rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
        else:
            raise NotImplementedError

        return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
            nosky_mask, sky_mask, sky_only_mask, new_idx

    def _forward_global(self, net_out, z):
        r"""Forward the CNN.

        Args:
            net_out (N x H x W x C5 tensor): Intermediate feature maps.
            z (N x C3 tensor): Intermediate style vectors.
        Returns:
            fake_images (N x 3 x H x W tensor): Output image.
            fake_images_raw (N x 3 x H x W tensor): Output image before TanH.
        """
        fake_images = net_out.permute(0, 3, 1, 2)
        fake_images_raw = self.denoiser(fake_images, z)
        fake_images = torch.tanh(fake_images_raw)

        return fake_images, fake_images_raw
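
# Minimal sketch of the alpha-compositing rule that mc_utils.volum_rendering_relu
# is assumed to implement (that helper is internal; this is the standard
# NeRF-style formulation, not a drop-in replacement). Each sample's blending
# weight is its opacity times the transmittance accumulated over the samples
# in front of it along the ray.
def _example_volume_rendering_weights(sigma, dists, dim=-2):
    alpha = 1.0 - torch.exp(-F.relu(sigma) * dists)  # per-sample opacity
    ones = torch.ones_like(alpha.narrow(dim, 0, 1))
    # Transmittance: exclusive cumulative product of (1 - alpha) over the
    # preceding samples (prepend 1 so the first sample is fully visible).
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha], dim=dim), dim=dim)
    trans = trans.narrow(dim, 0, alpha.size(dim))
    return alpha * trans  # blending weights; sum <= 1, remainder goes to sky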