Copy-and-Hallucinate Network

Module containing the implementation of the Copy-and-Hallucinate Network (CHN) and the methods required to train it using PyTorch Lightning.

class master_thesis.model_chn.CHN(model_vgg, model_aligner, **kwargs)

Bases: pytorch_lightning.core.lightning.LightningModule

Implementation of the Copy-and-Hallucinate Network (CHN).

LOSSES_NAMES = ['loss_nh', 'loss_vh', 'loss_nvh', 'loss_perceptual', 'loss_grad']
forward(x_target, v_target, x_refs_aligned, v_refs_aligned, v_maps)

Forward pass through the Copy-and-Hallucinate Network (CHN).

Parameters
  • x_target – tensor of size (B,C,H,W) containing the frame to inpaint.

  • v_target – tensor of size (B,1,H,W) containing the visibility map of the frame to inpaint.

  • x_refs_aligned – tensor of size (B,C,F,H,W) containing the aligned reference frames.

  • v_refs_aligned – tensor of size (B,C,F,H,W) containing the visibility maps of the aligned reference frames.

  • v_maps

Returns

A tuple containing both the raw output of the network and the composition of the ground-truth background and the inpainted hole.
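The tensor layouts listed above and the final composition step can be illustrated with a short sketch. It is not the module's code: it assumes that F aligned reference frames of size (B,C,H,W) are stacked along a new third dimension, and that visibility maps hold 1 for known pixels and 0 inside the hole. The helper names stack_refs and composite are hypothetical.

```python
import torch


def stack_refs(frames):
    """Stack F tensors of shape (B, C, H, W) into one tensor of shape (B, C, F, H, W)."""
    return torch.stack(frames, dim=2)


def composite(x_target, v_target, y_hat):
    """Keep known pixels from x_target and take hole pixels from the network output.

    Assumed convention: v_target is 1 on known (visible) pixels and 0 inside the hole.
    v_target has shape (B, 1, H, W) and broadcasts over the channel dimension.
    """
    return v_target * x_target + (1 - v_target) * y_hat
```

Under these assumptions, composite(x_target, v_target, y_hat) plays the role of the second element of the returned tuple, while y_hat is the raw network output.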

configure_optimizers()

Configures the optimizer used to train the Copy-and-Hallucinate Network (CHN).

Returns

Configured optimizer object.
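The constructor receives a VGG model and an aligner model in addition to the inpainting network. If those two are kept frozen (an assumption, not stated here), a common pattern is to give the optimizer only the parameters that require gradients. The sketch below assumes Adam with a learning rate of 1e-4; the helper name configure_optimizers_sketch is hypothetical, and the optimizer and hyperparameters actually used by configure_optimizers() are not documented.

```python
import torch
from torch import nn


def configure_optimizers_sketch(model: nn.Module, lr: float = 1e-4):
    """Build an optimizer over trainable parameters only.

    Frozen sub-models (e.g. a pretrained VGG used for a perceptual loss)
    have requires_grad=False and are therefore skipped.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)
```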

on_epoch_start()

Called when a training, validation, or test epoch begins.

training_step(batch, batch_idx)

Computes and returns the training loss and, optionally, additional metrics for the progress bar or logger.

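Parameters
  • batch – the output of the training DataLoader (a tensor, tuple, or list).

  • batch_idx – integer index of this batch.

Returns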

Any of:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'

  • None - Training will skip to the next batch. This is only for automatic optimization.

    This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

In this step you would normally run the forward pass and compute the loss for a batch. You can also do more elaborate things, such as multiple forward passes or other model-specific logic.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

If you enable truncated backpropagation through time (TBPTT), the step also receives an additional argument containing the hidden states from the previous step.

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    out, hiddens = self.lstm(batch, hiddens)
    loss = ...
    return {"loss": loss, "hiddens": hiddens}

Note

The loss value shown in the progress bar is smoothed (averaged over the most recent values), so it differs from the actual loss returned by the training or validation step.

align(x_target, m_target, x_ref, m_ref)

Aligns a target frame with respect to a reference frame using the selected aligner network.

Parameters
  • x_target

  • m_target

  • x_ref

  • m_ref

Returns

copy_and_hallucinate(x_target, m_target, x_ref_aligned, v_ref_aligned, v_map)
inpaint(x, m)
compute_loss(y_target, v_target, y_hat, y_hat_comp, v_map)
static get_indexes(size)
training: bool
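LOSSES_NAMES lists five loss terms returned by compute_loss(y_target, v_target, y_hat, y_hat_comp, v_map), but they are not documented. The sketch below shows one plausible way to combine such terms into a weighted total. Everything about it is an assumption: the reading of nh, vh and nvh as "non-hole", "hole visible in the aligned reference" and "hole not visible in the reference", the use of masked L1 distances, the finite-difference gradient loss, the weights dictionary, and v_map being a (B,1,H,W) map that is 1 where a hole pixel is visible in the reference. The perceptual term is left as zero because it would require the VGG model passed to CHN. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F

LOSSES_NAMES = ['loss_nh', 'loss_vh', 'loss_nvh', 'loss_perceptual', 'loss_grad']


def masked_l1(a, b, mask):
    """Mean absolute error restricted to pixels where mask == 1."""
    mask = mask.expand_as(a)  # (B, 1, H, W) -> (B, C, H, W)
    return ((a - b).abs() * mask).sum() / mask.sum().clamp(min=1.0)


def grad_loss(a, b):
    """L1 distance between horizontal and vertical finite differences."""
    a_dx, b_dx = a[..., :, 1:] - a[..., :, :-1], b[..., :, 1:] - b[..., :, :-1]
    a_dy, b_dy = a[..., 1:, :] - a[..., :-1, :], b[..., 1:, :] - b[..., :-1, :]
    return F.l1_loss(a_dx, b_dx) + F.l1_loss(a_dy, b_dy)


def compute_loss_sketch(y_target, v_target, y_hat, y_hat_comp, v_map, weights):
    # Hypothetical interpretation of the loss names (not confirmed by the module).
    hole = 1 - v_target
    losses = {
        'loss_nh': masked_l1(y_hat, y_target, v_target),
        'loss_vh': masked_l1(y_hat, y_target, hole * v_map),
        'loss_nvh': masked_l1(y_hat, y_target, hole * (1 - v_map)),
        # Placeholder: a real perceptual loss would compare VGG features.
        'loss_perceptual': torch.zeros(()),
        'loss_grad': grad_loss(y_hat_comp, y_target),
    }
    total = sum(weights[name] * losses[name] for name in LOSSES_NAMES)
    return total, losses
```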
class master_thesis.model_chn.RRDBNet(in_nc, out_nc, nb=10, nf=64, gc=32)

Bases: torch.nn.modules.module.Module

Residual-in-Residual Dense Block network (RRDBNet) implementing the architecture of the Copy-and-Hallucinate Network (CHN).

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class master_thesis.model_chn.RRDB(nf, gc=32)

Bases: torch.nn.modules.module.Module

Residual-in-Residual Dense Block (RRDB).

forward(x)

Defines the computation performed at every call.

training: bool
class master_thesis.model_chn.ResidualDenseBlock5C(nf=64, gc=32, bias=True)

Bases: torch.nn.modules.module.Module

forward(x)

Defines the computation performed at every call.

training: bool
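RRDBNet, RRDB and ResidualDenseBlock5C share their names and constructor arguments (nf, gc, nb) with the generator of ESRGAN, so the sketch below assumes they follow that design: a dense block of five 3x3 convolutions with gc growth channels and a residual scaled by 0.2, three such blocks chained inside an RRDB, and nb RRDBs placed between an input and an output convolution with a long skip connection. The 0.2 scaling, the LeakyReLU activation, and the omission of ESRGAN's upsampling stages (the inpainted output presumably keeps the input resolution) are assumptions. The Sketch suffix marks the classes as hypothetical reconstructions, not the module's code.

```python
import torch
from torch import nn


class ResidualDenseBlock5CSketch(nn.Module):
    """Five 3x3 convolutions with dense connections and a scaled residual."""

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # Conv i sees the block input plus the gc-channel outputs of all previous convs.
        self.convs = nn.ModuleList(
            [nn.Conv2d(nf + i * gc, gc, 3, 1, 1, bias=bias) for i in range(4)]
        )
        self.conv_out = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        features = [x]
        for conv in self.convs:
            features.append(self.lrelu(conv(torch.cat(features, dim=1))))
        out = self.conv_out(torch.cat(features, dim=1))
        return out * 0.2 + x  # scaled local residual


class RRDBSketch(nn.Module):
    """Three dense blocks in sequence, wrapped by another scaled residual."""

    def __init__(self, nf, gc=32):
        super().__init__()
        self.blocks = nn.Sequential(*[ResidualDenseBlock5CSketch(nf, gc) for _ in range(3)])

    def forward(self, x):
        return self.blocks(x) * 0.2 + x


class RRDBNetSketch(nn.Module):
    """Input conv, nb RRDBs with a long skip connection, output conv."""

    def __init__(self, in_nc, out_nc, nb=10, nf=64, gc=32):
        super().__init__()
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.trunk = nn.Sequential(*[RRDBSketch(nf, gc) for _ in range(nb)])
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)

    def forward(self, x):
        feat = self.conv_first(x)
        feat = feat + self.trunk_conv(self.trunk(feat))
        return self.conv_last(feat)
```

For example, RRDBNetSketch(in_nc=4, out_nc=3, nb=2) maps a (1,4,32,32) input to a (1,3,32,32) output, since every convolution uses stride 1 and padding 1. As with any nn.Module, call the instance (net(x)) rather than net.forward(x) so that registered hooks run.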