Copy-and-Hallucinate Network¶
Module containing the implementation of the Copy-and-Hallucinate Network (CHN) and the methods required to train it using PyTorch Lightning.
- class master_thesis.model_chn.CHN(model_vgg, model_aligner, **kwargs)¶
Bases:
pytorch_lightning.core.lightning.LightningModule
Implementation of the Copy-and-Hallucinate Network (CHN).
- LOSSES_NAMES = ['loss_nh', 'loss_vh', 'loss_nvh', 'loss_perceptual', 'loss_grad']¶
- forward(x_target, v_target, x_refs_aligned, v_refs_aligned, v_maps)¶
Forward pass through the Copy-and-Hallucinate Network (CHN).
- Parameters
x_target – tensor of size (B,C,H,W) containing the frame to inpaint.
v_target – tensor of size (B,1,H,W) containing the visibility map of the frame to inpaint.
x_refs_aligned – tensor of size (B,C,F,H,W) containing the aligned reference frames.
v_refs_aligned – tensor of size (B,1,F,H,W) containing the visibility maps of the aligned reference frames.
v_maps –
- Returns
A tuple containing both the raw output of the network and the composition of ground-truth background and inpainted hole.
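The composition of ground-truth background and inpainted hole can be sketched as follows. This is a minimal illustration, not the CHN's actual code: the tensor names mirror the signature above, the sizes are made up, and it assumes the common inpainting convention that the visibility map is 1 at known pixels.

```python
import torch

# Hypothetical sizes: batch B=2, channels C=3, spatial 64x64.
x_target = torch.rand(2, 3, 64, 64)                    # frame to inpaint
v_target = (torch.rand(2, 1, 64, 64) > 0.5).float()    # visibility map (1 = visible)
y_hat = torch.rand(2, 3, 64, 64)                       # raw network output

# Keep the ground-truth background where visible, and use the network's
# hallucination inside the hole (assumed convention: v == 1 means visible).
y_hat_comp = v_target * x_target + (1 - v_target) * y_hat
```

At visible pixels the composition reproduces `x_target` exactly, so only hole pixels depend on the network output.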
- configure_optimizers()¶
Configures the optimizer used to train the Copy-and-Hallucinate Network (CHN).
- Returns
Configured optimizer object.
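A minimal sketch of what such a method typically looks like. The optimizer type (Adam) and learning rate here are placeholder assumptions, not the thesis' actual settings, and the toy class subclasses plain `nn.Module` rather than `LightningModule` to keep the sketch self-contained.

```python
import torch

class TinyModule(torch.nn.Module):
    """Hypothetical module illustrating the configure_optimizers pattern."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

    # Same method name Lightning expects; Adam with lr=1e-4 is a placeholder choice.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

opt = TinyModule().configure_optimizers()
```

Lightning calls this method itself during `trainer.fit`; returning a single optimizer object, as here, is the simplest supported form.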
- on_epoch_start()¶
Called when a train, validation, or test epoch begins.
- training_step(batch, batch_idx)¶
Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
- Parameters
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader: a tensor, tuple or list.
batch_idx (int) – Integer displaying index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns
Any of:
Tensor – The loss tensor.
dict – A dictionary which can include any keys, but must include the key 'loss'.
None – Training will skip to the next batch. This is only for automatic optimization, and is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(data, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
- align(x_target, m_target, x_ref, m_ref)¶
Aligns a target frame with respect to a reference frame using the selected network aligner.
- Parameters
x_target –
m_target –
x_ref –
m_ref –
Returns:
- copy_and_hallucinate(x_target, m_target, x_ref_aligned, v_ref_aligned, v_map)¶
- inpaint(x, m)¶
- compute_loss(y_target, v_target, y_hat, y_hat_comp, v_map)¶
- static get_indexes(size)¶
- training: bool¶
- class master_thesis.model_chn.RRDBNet(in_nc, out_nc, nb=10, nf=64, gc=32)¶
Bases:
torch.nn.modules.module.Module
Network built from Residual in Residual Dense Blocks (RRDB), used within the CHN.
- forward(x)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
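The inherited note above can be demonstrated with a toy module: calling the instance runs registered forward hooks, while calling forward() directly skips them. The `Doubler` class is hypothetical, used only for this illustration.

```python
import torch

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

m = Doubler()
calls = []
# Record every time the module's forward pass is dispatched through __call__.
m.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

m(torch.tensor(3.0))          # instance call: hook fires
m.forward(torch.tensor(3.0))  # direct forward(): hook is silently skipped

# calls == ["hook"] — only the instance call triggered the hook
```

This is why the CHN's submodules should always be invoked as `module(x)` rather than `module.forward(x)`.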
- class master_thesis.model_chn.RRDB(nf, gc=32)¶
Bases:
torch.nn.modules.module.Module
Residual in Residual Dense Block
- forward(x)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
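The residual-in-residual pattern the block's name refers to can be sketched as follows. This is a simplified stand-in, not the thesis' implementation: the layer layout is reduced to two convolutions, and the 0.2 residual scaling follows the common ESRGAN convention and is an assumption here.

```python
import torch

class ToyRRDB(torch.nn.Module):
    """Simplified residual-in-residual block: inner residual branches
    are themselves wrapped in an outer residual connection."""
    def __init__(self, nf=8):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(nf, nf, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(nf, nf, 3, padding=1)

    def forward(self, x):
        # Inner residuals: each scaled branch output is added back to its input.
        h = x + 0.2 * torch.relu(self.conv1(x))
        h = h + 0.2 * torch.relu(self.conv2(h))
        # Outer residual around the whole block.
        return x + 0.2 * h

x = torch.rand(1, 8, 16, 16)
y = ToyRRDB()(x)
```

Because every path is a scaled residual added to an identity branch, gradients can flow around each sub-block, which is what makes deep stacks of these blocks trainable.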
- class master_thesis.model_chn.ResidualDenseBlock5C(nf=64, gc=32, bias=True)¶
Bases:
torch.nn.modules.module.Module
- forward(x)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶