Generative Models: Energy-Based Models, GANs, and VAEs

EE 641 - Unit 3

Dr. Brandon Franzke

Fall 2025

Introduction

Outline

Foundations & GANs

Energy-Based Models

  • Statistical physics foundations
  • Boltzmann machines and RBMs
  • Contrastive divergence training
  • Modern EBMs and score matching
  • Computational challenges

Generative Adversarial Networks

  • Minimax game formulation
  • Optimal discriminator derivation
  • Jensen-Shannon divergence connection
  • KL vs JS divergence trade-offs

GAN Training Dynamics

  • Vanishing gradients problem
  • Mode collapse and dropping
  • Nash equilibrium analysis
  • Non-saturating objectives

Training Stabilization

  • Label smoothing (0.7-1.0, 0.0-0.3)
  • Optimizer configurations (SGD/Adam)
  • Gradient penalties and spectral normalization
  • Production training pipelines

Wasserstein GANs

  • Earth mover’s distance
  • Lipschitz constraints
  • Gradient penalty (WGAN-GP)
  • Improved mode coverage

Architectures & VAEs

GAN Architectures

  • Transposed convolutions
  • DCGAN principles
  • Conditional GANs
  • Progressive growing
  • StyleGAN innovations
  • Pix2Pix and CycleGAN

Evaluation & Applications

  • Inception Score limitations
  • Fréchet Inception Distance (FID)
  • Sample size requirements
  • Non-vision applications

Variational Autoencoders

  • Latent variable models
  • Intractable posterior problem
  • Variational inference principle

ELBO & Training

  • Evidence lower bound derivation
  • KL divergence for Gaussians
  • Reparameterization trick
  • Numerical stability

VAE Variants

  • β-VAE for disentanglement
  • VQ-VAE with discrete codes
  • Hierarchical VAEs
  • Model comparisons

Reading List

  • [Score Matching] Y. Song and S. Ermon, “Generative modeling by estimating gradients of the data distribution,” in Advances in Neural Information Processing Systems, 2019, pp. 11918–11930.

  • [GAN] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio, “Generative adversarial nets,” in Advances in Neural Information Processing Systems, 2014, pp. 2672–2680.

  • [GAN Review] I. Goodfellow, “NIPS 2016 tutorial: Generative adversarial networks,” arXiv preprint arXiv:1701.00160, 2016.

  • [GAN Theory] S. Arora and Y. Zhang, “Do GANs actually learn the distribution? An empirical study,” arXiv preprint arXiv:1706.08224, 2017.

  • [DCGAN] A. Radford, L. Metz, and S. Chintala, “Unsupervised representation learning with deep convolutional generative adversarial networks,” in International Conference on Learning Representations, 2016.

  • [WGAN] M. Arjovsky, S. Chintala, and L. Bottou, “Wasserstein generative adversarial networks,” in International Conference on Machine Learning, 2017, pp. 214–223.

  • [WGAN-GP] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. C. Courville, “Improved training of Wasserstein GANs,” in Advances in Neural Information Processing Systems, 2017, pp. 5767–5777.

  • [BigGAN] A. Brock, J. Donahue, and K. Simonyan, “Large scale GAN training for high fidelity natural image synthesis,” in International Conference on Learning Representations, 2019.

  • [StyleGAN] T. Karras, S. Laine, and T. Aila, “A style-based generator architecture for generative adversarial networks,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 4401–4410.

  • [Pix2Pix] P. Isola, J.-Y. Zhu, T. Zhou, and A. A. Efros, “Image-to-image translation with conditional adversarial networks,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017, pp. 1125–1134.

  • [VAE] D. P. Kingma and M. Welling, “Auto-encoding variational Bayes,” in International Conference on Learning Representations, 2014.

  • [VQ-VAE] A. van den Oord, O. Vinyals, and K. Kavukcuoglu, “Neural discrete representation learning,” in Advances in Neural Information Processing Systems, 2017, pp. 6306–6315.

Energy-Based Models

Statistical Physics Foundation

Boltzmann Distribution and Energy Functions

Energy defines probability

\[p(\mathbf{x}) = \frac{1}{Z} \exp\left(-\frac{E(\mathbf{x})}{T}\right)\]

where partition function: \[Z = \int \exp\left(-\frac{E(\mathbf{x})}{T}\right) d\mathbf{x}\]

Why this distribution?

Maximizes entropy \(S = -\sum_i p_i \log p_i\) subject to:

  • Fixed average energy: \(\langle E \rangle = U\)
  • Normalization: \(\sum_i p_i = 1\)

Result from Lagrange multipliers: \[p_i \propto \exp(-\beta E_i)\] where \(\beta = 1/T\) is inverse temperature

Temperature controls exploration-exploitation:

  • Low \(T\) → concentrates on energy minima
  • High \(T\) → broad exploration
  • \(T \to 0\) → delta function at global minimum
  • Score function: \(\nabla_{\mathbf{x}} \log p(\mathbf{x}) = -\frac{1}{T}\nabla_{\mathbf{x}} E(\mathbf{x})\)
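A quick numerical sanity check of the temperature effect (a minimal sketch; the three state energies are made up for illustration):

import numpy as np

def boltzmann(E, T):
    # p_i ∝ exp(-E_i / T), normalized by the partition function Z
    w = np.exp(-np.asarray(E) / T)
    return w / w.sum()

E = [0.0, 1.0, 2.0]            # hypothetical state energies
print(boltzmann(E, T=0.1))     # ~[1, 0, 0]: concentrates on the minimum
print(boltzmann(E, T=10.0))    # ~[0.37, 0.33, 0.30]: near uniform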

Partition Function: Normalizing Constant Problem

Computing \(Z\) is intractable

\[Z(\boldsymbol{\theta}) = \int \exp(-E(\mathbf{x}; \boldsymbol{\theta})) d\mathbf{x}\]

What \(Z\) does:

  1. Normalizes probability: \(p(\mathbf{x}) = \frac{1}{Z} \exp(-E(\mathbf{x}))\)
  2. Generates moments: \(\langle f(\mathbf{x}) \rangle = \frac{\partial \log Z}{\partial \boldsymbol{\lambda}}\) for a conjugate term \(-\boldsymbol{\lambda}^T f(\mathbf{x})\) in the energy
  3. Defines free energy: \(F = -T \log Z\)

Computational reality for images (224×224×3):

  • Dimension \(d = 150,528\)
  • Discrete case: \(Z = \sum_{x \in \{0,1\}^d} e^{-E(x)}\)
  • States to evaluate: \(2^{150,528} \approx 10^{45,000}\)
  • Time to compute: longer than universe age

Approximation strategies:

  • Variational bounds: \(\log Z \geq \mathbb{E}_q[-E(\mathbf{x})] + H[q]\)
  • Monte Carlo: \(Z \approx \frac{1}{N}\sum_i \frac{e^{-E(x_i)}}{q(x_i)}\)
  • Mean field: Factorize \(p(\mathbf{x}) \approx \prod_i p_i(x_i)\)
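The Monte Carlo estimator above can be checked on a 1-D toy energy where \(Z\) is known in closed form (a minimal sketch; the proposal \(q = \mathcal{N}(0,1)\) and sample count are illustrative):

import numpy as np

rng = np.random.default_rng(0)

E = lambda x: 0.5 * (x - 1.0) ** 2             # toy energy: Z = sqrt(2*pi) exactly
x = rng.standard_normal(100_000)               # draws from the proposal q = N(0, 1)
log_q = -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

Z_hat = np.mean(np.exp(-E(x) - log_q))         # (1/N) sum exp(-E(x_i)) / q(x_i)
print(Z_hat, np.sqrt(2 * np.pi))               # both ~2.5066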

Free Energy and Derivatives

Helmholtz Free Energy \[F(\boldsymbol{\theta}) = -T \log Z(\boldsymbol{\theta})\]

Derivatives give expectations: \[\frac{\partial F}{\partial \theta_i} = \langle \frac{\partial E}{\partial \theta_i} \rangle_{p(\mathbf{x}|\boldsymbol{\theta})}\]

Second derivatives give (negative) covariances: \[\frac{\partial^2 F}{\partial \theta_i \partial \theta_j} = -\text{Cov}\left[\frac{\partial E}{\partial \theta_i}, \frac{\partial E}{\partial \theta_j}\right]\] for energies linear in \(\boldsymbol{\theta}\) (exponential families)

Physics interpretation:

  • \(F\) minimized at equilibrium
  • \(\nabla_{\boldsymbol{\theta}} F = 0\) defines critical points
  • Hessian \(\nabla^2 F\) determines stability

Learning problem:

  • Maximum likelihood: \(\log p(\mathbf{x}) = -E(\mathbf{x}) + F\) (at \(T=1\)), so gradients need \(\nabla_{\boldsymbol{\theta}} F\)
  • Requires computing expectations under the model distribution
  • Intractable for high dimensions!

Why \(F = -T \log Z\) connects physics and ML:

  • Entropy: \(S = U/T + \log Z\)
  • Free energy: \(F = U - TS = -T \log Z\)
  • ML: likelihood gradients are free-energy gradients (see the learning theorem below)
  • Temperature \(T\) controls exploration vs exploitation

Learning by Maximum Likelihood

Gradient of Log-Likelihood

Objective: Given data \(\{\mathbf{x}_1, ..., \mathbf{x}_N\}\), maximize: \[\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N}\sum_{i=1}^N \log p(\mathbf{x}_i|\boldsymbol{\theta})\]

Gradient of log-likelihood: \[\nabla_{\boldsymbol{\theta}} \log p(\mathbf{x}|\boldsymbol{\theta}) = -\nabla_{\boldsymbol{\theta}} E(\mathbf{x};\boldsymbol{\theta}) + \mathbb{E}_{p(\mathbf{x}'|\boldsymbol{\theta})}[\nabla_{\boldsymbol{\theta}} E(\mathbf{x}';\boldsymbol{\theta})]\]

Two phases:

  • Positive phase: \(-\nabla_{\boldsymbol{\theta}} E(\mathbf{x};\boldsymbol{\theta})\) (push down energy at data)
  • Negative phase: \(\mathbb{E}_{p}[\nabla_{\boldsymbol{\theta}} E(\mathbf{x}';\boldsymbol{\theta})]\) (push up elsewhere)

Critical Challenge: Computing negative phase requires samples from current model \(p(\mathbf{x}|\boldsymbol{\theta})\)!

Learning = Energy Sculpting: \[\nabla_{\boldsymbol{\theta}} \log p(\text{data}) = \underbrace{-\nabla_{\boldsymbol{\theta}} E(\text{data})}_{\text{push down}} + \underbrace{\mathbb{E}_{\text{model}}[\nabla_{\boldsymbol{\theta}} E]}_{\text{push up}}\]

Positive phase (easy) vs Negative phase (intractable expectation)

The Fundamental Learning Theorem for EBMs

Theorem: For any energy-based model, the log-likelihood gradient decomposes as:

\[\frac{\partial}{\partial \boldsymbol{\theta}} \log p(\mathbf{x}^{(n)}; \boldsymbol{\theta}) = -\frac{\partial E(\mathbf{x}^{(n)}; \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} + \frac{\partial F(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}}\]

Proof: Starting from \(p(\mathbf{x}; \boldsymbol{\theta}) = \frac{1}{Z(\boldsymbol{\theta})} e^{-E(\mathbf{x};\boldsymbol{\theta})}\):

\[\frac{\partial}{\partial \boldsymbol{\theta}} \log p(\mathbf{x}^{(n)}) = \frac{\partial}{\partial \boldsymbol{\theta}} [-E(\mathbf{x}^{(n)}) - \log Z(\boldsymbol{\theta})]\]

\[= -\frac{\partial E(\mathbf{x}^{(n)})}{\partial \boldsymbol{\theta}} - \frac{1}{Z} \frac{\partial Z}{\partial \boldsymbol{\theta}}\]

Since \(\frac{\partial Z}{\partial \boldsymbol{\theta}} = \int \frac{\partial}{\partial \boldsymbol{\theta}} e^{-E(\mathbf{x};\boldsymbol{\theta})} d\mathbf{x} = -\int e^{-E(\mathbf{x};\boldsymbol{\theta})} \frac{\partial E}{\partial \boldsymbol{\theta}} d\mathbf{x}\):

\[\frac{1}{Z} \frac{\partial Z}{\partial \boldsymbol{\theta}} = -\mathbb{E}_{p(\mathbf{x};\boldsymbol{\theta})}\left[\frac{\partial E}{\partial \boldsymbol{\theta}}\right] = \frac{\partial F}{\partial \boldsymbol{\theta}}\]

Fisher Information Connection: \(I(\boldsymbol{\theta}) = \text{Cov}\left[\frac{\partial E}{\partial \boldsymbol{\theta}}\right] = -\frac{\partial^2 F}{\partial \boldsymbol{\theta}^2}\)

Why Modern ML Exists:

Positive phase: \(-\frac{\partial E(data)}{\partial \boldsymbol{\theta}}\)

  • Trivial to compute
  • Just evaluate at observed data

Negative phase: \(\frac{\partial F}{\partial \boldsymbol{\theta}}\)

  • Intractable expectation over model
  • Requires sampling from current distribution

Modern Solutions:

  • GANs: Avoid explicit densities → no partition function
  • VAEs: Tractable bounds → no exact negative phase
  • Diffusion: Score matching → only energy gradients
  • Contrastive: Approximate negative phase → short MCMC

Main Problem: All tractable learning requires avoiding the negative phase computation.

The Sampling Problem

Markov Chain Monte Carlo (MCMC)

To sample from \(p(\mathbf{x}) \propto \exp(-E(\mathbf{x}))\):

Langevin Dynamics: \[\mathbf{x}_{t+1} = \mathbf{x}_t - \frac{\epsilon}{2}\nabla_{\mathbf{x}} E(\mathbf{x}_t) + \sqrt{\epsilon}\boldsymbol{\eta}_t\] where \(\boldsymbol{\eta}_t \sim \mathcal{N}(0, \mathbf{I})\)
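A minimal 1-D sketch of Langevin sampling from a double-well energy (step size and chain length are illustrative; high-dimensional, well-separated modes are exactly where this breaks down):

import numpy as np

rng = np.random.default_rng(0)

E_grad = lambda x: 4 * x**3 - 4 * x    # gradient of E(x) = x^4 - 2x^2 (double well)
eps = 0.01
x, samples = 0.0, []
for t in range(50_000):
    x = x - 0.5 * eps * E_grad(x) + np.sqrt(eps) * rng.standard_normal()
    samples.append(x)
# samples concentrate near the two minima at x = ±1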

Hamiltonian Monte Carlo:

  • Auxiliary momentum: \(\mathbf{p} \sim \mathcal{N}(0, \mathbf{M})\)
  • Simulate: \(H(\mathbf{x}, \mathbf{p}) = E(\mathbf{x}) + \frac{1}{2}\mathbf{p}^T\mathbf{M}^{-1}\mathbf{p}\)
  • Accept/reject with Metropolis

Why MCMC fails at scale:

  • Well-separated modes → exponential mixing time
  • High dimensions → slow exploration
  • ImageNet scale: computationally impossible

Gradient Computation in Practice

Full gradient requires equilibrium samples: \[\nabla_{\boldsymbol{\theta}} \mathcal{L} = \underbrace{-\frac{1}{N}\sum_{i=1}^N \nabla_{\boldsymbol{\theta}} E(\mathbf{x}_i)}_{\text{Data term (easy)}} + \underbrace{\mathbb{E}_{p_{\boldsymbol{\theta}}}[\nabla_{\boldsymbol{\theta}} E(\mathbf{x})]}_{\text{Model term (hard!)}}\]

Computational cost per gradient step:

  1. Run MCMC to equilibrium: \(O(T_{\text{mix}} \times d)\)
  2. Collect K samples for Monte Carlo estimate
  3. Compute energy gradients: \(O(K \times \text{forward pass})\)

For images (\(d = 150\)K):

  • \(T_{\text{mix}} \approx 10^6\) steps (optimistic!)
  • Time per gradient: hours to days
  • SGD needs 10K+ gradients → years!

Approximation strategies emerged from this problem:

  • Contrastive Divergence: use short chains
  • Score matching: avoid sampling entirely
  • Variational methods: tractable bounds

Restricted Boltzmann Machines (RBM)

RBM Architecture

Bipartite Structure

Visible units \(\mathbf{v} \in \{0,1\}^D\), Hidden units \(\mathbf{h} \in \{0,1\}^H\)

Energy function: \[E(\mathbf{v}, \mathbf{h}) = -\mathbf{v}^T \mathbf{W} \mathbf{h} - \mathbf{b}^T \mathbf{v} - \mathbf{c}^T \mathbf{h}\]

Conditional independence: \[p(\mathbf{h}|\mathbf{v}) = \prod_{j=1}^H p(h_j|\mathbf{v})\] \[p(\mathbf{v}|\mathbf{h}) = \prod_{i=1}^D p(v_i|\mathbf{h})\]

where: \[p(h_j = 1|\mathbf{v}) = \sigma(\mathbf{W}_{:,j}^T \mathbf{v} + c_j)\] \[p(v_i = 1|\mathbf{h}) = \sigma(\mathbf{W}_{i,:} \mathbf{h} + b_i)\]

Why RBMs work: Tractable Gibbs sampling

  • Sample \(\mathbf{h} \sim p(\mathbf{h}|\mathbf{v})\) in parallel (factorizes)
  • Sample \(\mathbf{v} \sim p(\mathbf{v}|\mathbf{h})\) in parallel (factorizes)
  • No within-layer connections → faster mixing than general models

Contrastive Divergence

The CD-k Algorithm

Instead of running chain to equilibrium, use k steps:

  1. Start from data: \(\mathbf{v}^{(0)} = \mathbf{x}_{\text{data}}\)
  2. Sample \(\mathbf{h}^{(0)} \sim p(\mathbf{h}|\mathbf{v}^{(0)})\)
  3. Sample \(\mathbf{v}^{(1)} \sim p(\mathbf{v}|\mathbf{h}^{(0)})\)
  4. Repeat k times…
  5. Approximate: \(\mathbb{E}_{p_{\text{model}}} \approx \mathbb{E}_{\text{CD-k}}\)

Gradient approximation: \[\Delta \mathbf{W} \approx \langle \mathbf{v}\mathbf{h}^T \rangle_{\text{data}} - \langle \mathbf{v}\mathbf{h}^T \rangle_{\text{CD-k}}\]

CD-1 often sufficient! (k=1)
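A minimal CD-1 update for a binary RBM (a sketch, assuming \(\mathbf{W}\) of shape D×H and biases \(\mathbf{b}\), \(\mathbf{c}\) as defined above; the learning rate is illustrative):

import torch

def cd1_update(W, b, c, v_data, lr=1e-2):
    # Positive phase: p(h|v) factorizes -> sample all hidden units in parallel
    p_h = torch.sigmoid(v_data @ W + c)
    h = torch.bernoulli(p_h)
    # One Gibbs step: reconstruct v, then recompute hidden probabilities
    v_model = torch.bernoulli(torch.sigmoid(h @ W.T + b))
    p_h_model = torch.sigmoid(v_model @ W + c)
    # CD-1 gradient: <v h^T>_data - <v h^T>_CD-1
    B = v_data.shape[0]
    W += lr * (v_data.T @ p_h - v_model.T @ p_h_model) / B
    b += lr * (v_data - v_model).mean(0)
    c += lr * (p_h - p_h_model).mean(0)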

What CD actually does:

  • Doesn’t minimize KL(p_data || p_model) directly
  • Minimizes difference of KL divergences
  • Biases learning towards data manifold
  • Poor at exploring between modes

CD Gradient Dynamics

What CD actually optimizes:

Not KL(p_data || p_model), but: \[\text{CD}_k = \text{KL}(p_{\text{data}} || p_{\text{model}}) - \text{KL}(p_k || p_{\text{model}})\]

where \(p_k\) = distribution after \(k\) Gibbs steps from data

Partition function cancellation:

\[\frac{\partial \text{CD}_k}{\partial \boldsymbol{\theta}} = \mathbb{E}_{p_{\text{data}}}[\nabla_{\boldsymbol{\theta}} E] - \mathbb{E}_{p_k}[\nabla_{\boldsymbol{\theta}} E] + \underbrace{(\nabla_{\boldsymbol{\theta}} \log Z - \nabla_{\boldsymbol{\theta}} \log Z)}_{=0}\]

The problematic \(\nabla_{\boldsymbol{\theta}} \log Z(\boldsymbol{\theta})\) terms cancel.

Tractable gradient computation without equilibrium sampling!

Implications:

  • \(p_k\) stays close to \(p_{\text{data}}\) initially
  • Model learns data manifold quickly
  • Between-mode regions poorly modeled
  • Can create spurious modes

Persistent CD: Continue chains across minibatches

  • Reduces bias
  • Better negative samples
  • More computation

Connection to score matching:

  • CD-1 gradient ≈ score matching gradient at data points
  • Explains why CD-1 works despite severe bias
  • Led to development of denoising score matching

Modern Energy Models

Score Matching

Avoid the partition function entirely!

Score function: \(\mathbf{s}(\mathbf{x}; \boldsymbol{\theta}) = \nabla_{\mathbf{x}} \log p(\mathbf{x}; \boldsymbol{\theta})\)

For EBM: \(\mathbf{s}(\mathbf{x}) = -\nabla_{\mathbf{x}} E(\mathbf{x}; \boldsymbol{\theta})\) (no \(Z\)!)

Score Matching Objective: \[J(\boldsymbol{\theta}) = \frac{1}{2}\mathbb{E}_{p_{\text{data}}}[||\mathbf{s}(\mathbf{x}; \boldsymbol{\theta}) - \nabla_{\mathbf{x}} \log p_{\text{data}}(\mathbf{x})||^2]\]

Problem: Don’t know \(\nabla_{\mathbf{x}} \log p_{\text{data}}\)!

Solution (Integration by parts): \[J(\boldsymbol{\theta}) = \mathbb{E}_{p_{\text{data}}}[\text{tr}(\nabla_{\mathbf{x}} \mathbf{s}(\mathbf{x}; \boldsymbol{\theta})) + \frac{1}{2}||\mathbf{s}(\mathbf{x}; \boldsymbol{\theta})||^2] + C\]

Why score matching works: No sampling required!

  • Direct optimization on data distribution
  • No MCMC needed
  • Scales to high dimensions

Denoising Score Matching

Practical score matching via denoising

Perturb data: \(\tilde{\mathbf{x}} = \mathbf{x} + \boldsymbol{\epsilon}\), where \(\boldsymbol{\epsilon} \sim \mathcal{N}(0, \sigma^2 \mathbf{I})\)

Result (Vincent, 2011): \[\mathbb{E}_{\tilde{\mathbf{x}}}[||\mathbf{s}(\tilde{\mathbf{x}}) - \nabla_{\tilde{\mathbf{x}}} \log p(\tilde{\mathbf{x}}|\mathbf{x})||^2]\] \[= \mathbb{E}_{\tilde{\mathbf{x}}}[||\mathbf{s}(\tilde{\mathbf{x}}) + \frac{\tilde{\mathbf{x}} - \mathbf{x}}{\sigma^2}||^2] + C\]

Denoising Autoencoder Connection:

  • Train network to predict: \(\mathbf{x} - \tilde{\mathbf{x}}\)
  • Equivalent to learning score function!
  • Network output = \(\sigma^2 \mathbf{s}(\tilde{\mathbf{x}})\)
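The denoising objective reduces to a few lines (a minimal sketch; score_net is any network mapping \(\tilde{\mathbf{x}}\) to a tensor of the same shape, and σ is a fixed noise level):

import torch

def dsm_loss(score_net, x, sigma=0.1):
    # Perturb: x_tilde = x + eps, eps ~ N(0, sigma^2 I)
    eps = torch.randn_like(x) * sigma
    x_tilde = x + eps
    # Score of the perturbation kernel: grad log p(x_tilde|x) = -(x_tilde - x)/sigma^2
    target = -eps / sigma**2
    return 0.5 * ((score_net(x_tilde) - target) ** 2).flatten(1).sum(1).mean()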

Connection to diffusion models:

  • Multiple noise levels → annealed score matching
  • Learn \(\mathbf{s}(\mathbf{x}, t)\) for different noise scales
  • Basis for DDPM and score-based models

Deep Energy-Based Models

Neural Networks as Energy Functions

Energy: \(E(\mathbf{x}; \boldsymbol{\theta}) = ||f_{\boldsymbol{\theta}}(\mathbf{x})||^2\)

Or more generally: \[E(\mathbf{x}; \boldsymbol{\theta}) = -\log \sum_y \exp(f_{\boldsymbol{\theta}}(\mathbf{x}, y))\]

Training with Langevin Dynamics:

# Sampling loop (inner): Langevin dynamics on the current energy
x = x_init.clone().requires_grad_(True)
for t in range(T):
    grad_x = torch.autograd.grad(E(x).sum(), x)[0]
    x = (x - lam * grad_x
         + (2 * lam) ** 0.5 * torch.randn_like(x)).detach().requires_grad_(True)

# Parameter update (outer): push down data energy, push up sample energy
optimizer.zero_grad()
loss = E(x_data).mean() - E(x.detach()).mean()
loss.backward()          # populates gradients of theta
optimizer.step()

Memory Requirements:

  • Store replay buffer of samples
  • Persistent chains across updates
  • 10-100× more memory than discriminative

Problems with deep EBMs:

  • Mode coverage remains poor
  • Training instability at high capacity
  • Sampling still expensive (T~1000 steps)
  • Evaluation metrics problematic

EBM Limitations and Legacy

Why EBMs Failed at Scale

Computational Reality

ImageNet image: 224 × 224 × 3 = 150,528 dims

Per gradient step:

  1. Sample from model: ~10,000 Langevin steps
  2. Each step: Forward + backward pass
  3. Total: 20,000 network evaluations
  4. Time: ~5 minutes on V100

Per epoch (1.2M images):

  • Discriminative model: 2 hours
  • EBM: 2 hours × 10,000 = 2.3 years

Memory explosion:

  • Replay buffer: 10,000 samples × 150KB = 1.5GB
  • Persistent chains: Another 1.5GB
  • Gradients during sampling: 4× model size

Fundamental problem beyond computation:

  • High-dimensional spaces are mostly empty
  • Modes separated by vast low-probability regions
  • Local moves can’t explore efficiently

What Survived from EBMs

Score Matching → Diffusion Models

  • Denoising score matching at multiple scales
  • DDPM, Score-based models
  • State-of-the-art generation quality

Contrastive Learning → Self-Supervised

  • InfoNCE loss is contrastive divergence
  • SimCLR, MoCo use energy concepts
  • CLIP: joint energy over modalities

Energy Functions → Implicit Models

  • GANs: discriminator as energy difference
  • Flow models: tractable energy via invertibility
  • VAEs: variational bound on log-likelihood

Lessons learned:

  • Explicit density modeling is hard
  • Implicit methods can sidestep intractability
  • Local approximations (CD) inspire global solutions
  • Score functions more tractable than densities

Legacy → Modern ML:

  • EBM partition functions → VAE tractable bounds
  • EBM energy functions → GAN implicit densities
  • EBM score functions → Diffusion models

EBMs beget GANs

EBMs: Model \(p(\mathbf{x})\) explicitly

  • Requires normalizing constant \(Z\)
  • Bottleneck: Sampling from model
  • Result: Computationally intractable

GANs: Generate samples directly

  • Requires only a discriminator
  • Bottleneck: Training stability
  • Result: Fast, high-quality samples

What changes: Replace intractable sampling with adversarial training

Instead of computing \(\mathbb{E}_{p_{\text{model}}}[f(\mathbf{x})]\), use a discriminator to estimate density ratios

GANs trade computational intractability for training instability:

  • No partition function needed
  • Direct sample generation
  • Training instability becomes main challenge
  • Mode collapse replaces mixing problems

Generative Adversarial Networks (GANs)

What GANs Create

This Does Not Exist

Generated samples from modern GANs (2024):

  • thispersondoesnotexist.com (StyleGAN2)
  • thischemicaldoesnotexist.com (MolGAN)
  • thisrentaldoesnotexist.com (HouseGAN)
  • thismusicvideodoesnotexist.com

Generated faces from StyleGAN2

Quality metrics:

  • Resolution: up to 1024×1024
  • FID score: < 3.0 on FFHQ
  • Training: 7 days on 8 V100s
  • Parameters: 30M (StyleGAN2)

Various GAN applications

Computational requirements:

  • Inference: 50ms per image (V100)
  • Memory: 8GB for generation
  • Training data: 70k images minimum

Original GAN Results (2014)

Goodfellow et al. inaugural results:

Original GAN samples MNIST/TFD

Architecture (fully connected):

  • Generator: 100 → 1200 → 1200 → 784
  • Discriminator: 784 → 240 → 240 → 1
  • Activation: ReLU (G), Maxout (D)
  • No batch norm (not invented yet)

Quantitative results (MNIST):

  • Log-likelihood: 225 ± 2 (Parzen window estimate)
  • Training time: 2 hours (GTX 580)
  • Mode coverage: ~70% of digits

Historical impact:

  • First implicit generative model
  • No Markov chains or inference networks
  • Spawned 50,000+ papers (2014-2024)
  • Led to StyleGAN, DALL-E 2, Stable Diffusion

Minimax Objective

KL Divergence Review

Kullback-Leibler divergence

\[\text{KL}(p||q) = \mathbb{E}_{x \sim p}\left[\log \frac{p(x)}{q(x)}\right]\]

Properties:

  • \(\text{KL}(p||q) \geq 0\) (equality iff \(p = q\))
  • Not symmetric: \(\text{KL}(p||q) \neq \text{KL}(q||p)\)
  • Not a metric (no triangle inequality)

Forward KL \((p||q)\): Mean-seeking

  • \(q\) covers all of \(p\)
  • \(q\) spreads to avoid infinite penalty

Reverse KL \((q||p)\): Mode-seeking

  • \(q\) concentrates on modes of \(p\)
  • \(q\) can ignore parts of \(p\)

GAN behavior matches reverse KL:

  • Generator picks modes rather than covering all
  • Mode collapse is optimal given objective
  • Forward KL would prevent this but intractable

The Adversarial Objective

Minimax game between two networks:

\[\min_G \max_D V(D,G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log(1-D(G(\mathbf{z})))]\]

Interpretation as binary classification:

Discriminator \(D\) solves:

  • Class 1: Real data \(\mathbf{x} \sim p_{\text{data}}\)
  • Class 0: Fake data \(G(\mathbf{z})\), \(\mathbf{z} \sim p(\mathbf{z})\)
  • Output: \(D(\mathbf{x}) = P(\mathbf{x} \text{ is real})\)

Generator \(G\):

  • Produces samples \(G(\mathbf{z})\) to fool \(D\)
  • No direct access to real data
  • Only gradient signal from \(D\)

Optimal Discriminator: Complete Derivation

Step-by-step derivation for fixed \(G\):

Starting from: \[V(D,G) = \int_{\mathbf{x}} p_{\text{data}}(\mathbf{x})\log D(\mathbf{x}) + p_g(\mathbf{x})\log(1-D(\mathbf{x})) d\mathbf{x}\]

For any \(\mathbf{x}\), maximize integrand: \[f(y) = a \log(y) + b \log(1-y)\] where \(a = p_{\text{data}}(\mathbf{x})\), \(b = p_g(\mathbf{x})\), \(y = D(\mathbf{x})\)

Taking derivative: \[\frac{df}{dy} = \frac{a}{y} - \frac{b}{1-y}\]

Setting to zero and solving: \[\frac{a}{y} = \frac{b}{1-y} \Rightarrow a(1-y) = by\] \[a - ay = by \Rightarrow a = y(a+b)\]

Therefore: \[D^*(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\]

Verification: Second derivative \(\frac{d^2f}{dy^2} = -\frac{a}{y^2} - \frac{b}{(1-y)^2} < 0\) confirms maximum.

Substitution Gives JS Divergence

With \(D = D^*\), generator objective becomes:

\[C(G) = \max_D V(D,G) = V(D^*, G)\]

Substituting \(D^*\): \[C(G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right]\] \[+ \mathbb{E}_{\mathbf{x} \sim p_g}\left[\log \frac{p_g(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right]\]

This equals: \[C(G) = -\log 4 + 2 \cdot \text{JS}(p_{\text{data}} || p_g)\]

where Jensen-Shannon divergence: \[\text{JS}(p||q) = \frac{1}{2}\text{KL}(p||m) + \frac{1}{2}\text{KL}(q||m)\] \[m = \frac{1}{2}(p + q)\]

Properties of JS divergence:

  • Symmetric: \(\text{JS}(p||q) = \text{JS}(q||p)\)
  • Bounded: \(0 \leq \text{JS}(p||q) \leq \log 2\)
  • \(\text{JS} = 0\) iff \(p = q\)
  • Smooth even when supports don’t overlap

KL vs JS: Mode Coverage Trade-offs

Forward KL: \(D_{KL}(p_{data}||p_g)\) \[\mathbb{E}_{x \sim p_{data}}\left[\log \frac{p_{data}(x)}{p_g(x)}\right]\]

  • Penalty when \(p_{data} > 0, p_g = 0\):
  • Penalty when \(p_{data} = 0, p_g > 0\): 0
  • Result: Generator covers all modes (blurry)

Reverse KL: \(D_{KL}(p_g||p_{data})\)

  • Penalty when \(p_g > 0, p_{data} = 0\):
  • Penalty when \(p_g = 0, p_{data} > 0\): 0
  • Result: Generator picks few modes (sharp)

JS Divergence (GAN objective): \[JS(p||q) = \frac{1}{2}KL(p||m) + \frac{1}{2}KL(q||m)\]

  • Symmetric compromise
  • Bounded: \(0 \leq JS \leq \log 2\)
  • Sharp samples, moderate mode coverage
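These trade-offs can be checked numerically on a toy two-mode distribution (a minimal sketch; the eps guard stands in for the infinite forward-KL penalty):

import numpy as np

def kl(p, q, eps=1e-12):
    # KL(p||q) on a discrete support; eps guards division by zero
    p, q = np.asarray(p, float), np.asarray(q, float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))

def js(p, q):
    m = 0.5 * (np.asarray(p, float) + np.asarray(q, float))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p_data = [0.5, 0.5]     # two equally likely modes
p_g = [1.0, 0.0]        # generator collapsed onto one mode

print(kl(p_data, p_g))  # forward KL: ~13 with the guard (infinite in the limit)
print(kl(p_g, p_data))  # reverse KL: log 2 ~ 0.69 -> collapse barely penalized
print(js(p_data, p_g))  # JS: ~0.22, bounded by log 2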

GAN Optimization and Instability

Gradient Flow Analysis

Generator gradient: \[\nabla_{\theta_G} V = \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}\left[\nabla_{\theta_G} \log(1-D(G(\mathbf{z};\theta_G)))\right]\]

Expanding: \[\nabla_{\theta_G} V = -\mathbb{E}_{\mathbf{z}}\left[\frac{1}{1-D(G(\mathbf{z}))} \cdot \nabla_{\theta_G} D(G(\mathbf{z}))\right]\]

Problem when D is too good:

  • Early training: \(D(G(\mathbf{z})) \approx 0\) (near-perfect discrimination)
  • Loss prefactor \(\frac{1}{1-D} \approx 1\), while \(\nabla_{\theta_G} D(G(\mathbf{z})) \to 0\) as the sigmoid saturates
  • Even at \(D(G(\mathbf{z})) = 0.01\) the prefactor stays ≈ 1: no amplification
  • Generator learns very slowly!

Vanishing gradients when \(D\) too good:

  • Generator gradient → 0 as \(D\) → perfect
  • Training stalls early
  • No learning signal for \(G\)

Non-Saturating Objective

Modified generator objective:

Instead of: \(\min_G \mathbb{E}_{\mathbf{z}}[\log(1-D(G(\mathbf{z})))]\)

Use: \(\max_G \mathbb{E}_{\mathbf{z}}[\log D(G(\mathbf{z}))]\)

Same optimum, different dynamics: \[\nabla_{\theta_G} = \mathbb{E}_{\mathbf{z}}\left[\frac{1}{D(G(\mathbf{z}))} \cdot \nabla_{\theta_G} D(G(\mathbf{z}))\right]\]

When \(D(G(\mathbf{z})) \approx 0.01\):

  • Original prefactor: \(1/(1-0.01) \approx 1\) (weak)
  • Non-saturating prefactor: \(1/0.01 = 100\) (strong!)

Changes gradient dynamics completely

  • Strong signal when G is losing
  • Prevents early training collapse
  • Standard practice in all implementations
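A one-tensor sketch of the two gradient regimes (the 0.01 value mirrors the early-training scenario above):

import torch

# D's probability that a fake is real, early in training (G losing badly)
d_fake = torch.tensor(0.01, requires_grad=True)

# Saturating loss: minimize log(1 - D(G(z)))
torch.log(1 - d_fake).backward()
print(d_fake.grad)      # -1/(1 - 0.01) ~ -1.01: weak signal

d_fake.grad = None

# Non-saturating loss: minimize -log D(G(z))
(-torch.log(d_fake)).backward()
print(d_fake.grad)      # -1/0.01 = -100: strong signal when G is losing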

Mode Collapse Mechanics

Reverse KL behavior in GANs:

Generator minimizes (approximately): \[\text{KL}(p_g || p_{\text{data}})\]

Properties of reverse KL:

  • Zero when p_g = 0 but p_data > 0 (can ignore modes!)
  • Infinite when p_g > 0 but p_data = 0 (avoid generating outside data)
  • Prefers to match single mode perfectly

Sequential mode hopping:

  1. \(G\) focuses on one mode
  2. \(D\) learns to reject that mode
  3. \(G\) jumps to different mode
  4. Cycle repeats

Why mode collapse happens:

  • No penalty for missing modes (reverse KL = 0)
  • Easier to fool D with perfect single mode
  • Optimal for generator given current D

GAN Training Computational Cost

Per iteration costs:

For batch size B, image size H×W×C:

  • Generator forward: \(O(B \times \text{params}_G)\)
  • Discriminator forward: \(O(2B \times \text{params}_D)\)
    • Once for real data
    • Once for fake data
  • Backward passes: Similar cost to forward

Typical architecture sizes:

  • DCGAN (64×64): G=3.5M, D=2.8M params
  • StyleGAN2 (1024×1024): G=30M, D=24M params
  • Memory: ~4× model size (gradients, momentum)

Training time comparison:

  • DCGAN on CIFAR-10: 2-6 hours (single GPU)
  • StyleGAN2 on FFHQ: 7-10 days (8 V100s)
  • BigGAN on ImageNet: 2-4 days (256 TPUs)

Cost vs other methods:

Method | Forward/iter | Memory | Training stability
VAE | 2× | params | Stable
GAN | 4× | params | Unstable
EBM (CD-1) | 5× | params | Biased
EBM (Full) | 1000× | 10× params | Accurate

D/G update ratio affects cost:

  • Standard: 1:1 updates
  • WGAN: 5:1 (D:G) → 5× discriminator cost
  • Progressive GAN: Dynamic ratio

Computational bottlenecks:

  • High-res generation: Memory limited
  • Batch size critical for stability (≥32)
  • FP16 training: 2× speedup, stability issues

Trade-off: GANs are 3× slower than VAEs per iteration but produce sharper samples

Nash Equilibrium Analysis

Nash equilibrium exists:

  • p_g = p_data
  • D(x) = 1/2 everywhere
  • Neither player can improve unilaterally

But not unique and not stable:

  • Multiple equilibria possible
  • Gradient dynamics don’t converge
  • Training finds limit cycles

Actual training dynamics:

  • Oscillations around equilibrium
  • Never settles to single point
  • D and G chase each other

Theoretical result (Goodfellow et al.): If G and D have enough capacity and updates are small:

  • Algorithm converges to p_g = p_data
  • Reality: assumptions rarely hold

What happens in practice:

  • Training doesn’t converge to equilibrium
  • Oscillations continue indefinitely
  • Stop based on sample quality metrics
  • Moving average of G weights helps stability

Practical Training Stabilization

Training Tricks That Work

Label Smoothing (Salimans et al. 2016):

# Instead of hard labels 0 and 1
real_labels = torch.ones(batch_size) 
fake_labels = torch.zeros(batch_size)

# Use soft labels
real_labels = 0.7 + 0.3 * torch.rand(batch_size)  # [0.7, 1.0]
fake_labels = 0.0 + 0.3 * torch.rand(batch_size)  # [0.0, 0.3]

Impact: 15% reduction in mode collapse frequency

Optimizer Configuration:

# Discriminator: SGD with momentum
opt_D = torch.optim.SGD(D.parameters(), 
                        lr=0.0002, momentum=0.9)

# Generator: Adam for stability
opt_G = torch.optim.Adam(G.parameters(), 
                         lr=0.0001, betas=(0.5, 0.999))

Update Ratio:

  • Standard GAN: 1:1 (D:G)
  • WGAN: 5:1 (critical for Lipschitz)
  • When D loss < 0.5: increase D updates
  • When D loss > 0.8: increase G updates

Batch Size Impact:

Batch Size | Training Time | FID Score | Stability
8 | 48h | 45.2 | Poor
32 | 24h | 28.4 | Moderate
64 | 18h | 22.1 | Good
128 | 12h | 20.3 | Best
256 | 10h | 21.8 | Good*

*Diminishing returns, memory limited

Gradient Penalties:

# Gradient clipping (basic)
torch.nn.utils.clip_grad_norm_(
    G.parameters(), max_norm=10.0)

# Spectral normalization (better): wrap each weight layer
layer = torch.nn.utils.spectral_norm(layer)  # constrains σ(W) ≈ 1

# R1 regularization (StyleGAN): penalize D's gradient on real images
grad = autograd.grad(d_real.sum(), real_img, create_graph=True)[0]
r1_loss = grad.pow(2).flatten(1).sum(1).mean() / 2

Memory overhead: +25% for gradient penalties

Initialization and Normalization

Weight Initialization:

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        # He initialization for ReLU
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.ConvTranspose2d):
        # Xavier for generator
        nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)

Normalization Strategies:

Network | Normalization | Location | Why
Generator | BatchNorm | All except output | Stabilizes gradients
Discriminator | None/LayerNorm | Optional | BN causes correlation
Both | SpectralNorm | All layers | Enforces Lipschitz

Impact on training:

  • Without init: 60% fail in first 1k iterations
  • With proper init: <5% early failure
  • Spectral norm: 2× longer stable training

Debugging Failed Training

Common Failure Modes and Fixes:

1. Mode Collapse

# Symptoms: D loss → 0, G produces identical outputs
# Check: per-pixel std across a generated batch
diversity = torch.std(generated_batch, dim=0).mean()
if diversity < threshold:
    # Fixes: reduce learning rate, add noise to inputs,
    # increase batch size, or use unrolled GANs
    pass

2. Vanishing Gradients

# Monitor gradient norms
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        if grad_norm < 1e-5:
            print(f"Vanishing gradient in {name}")
# Fix: Use WGAN-GP or non-saturating loss

3. Oscillation

# Track losses over a sliding window
loss_history.append(d_loss.item())
if len(loss_history) > 100:
    variance = np.var(loss_history[-100:])
    if variance > threshold:
        # Oscillating: decay both learning rates
        scheduler.step()

Diagnostic Metrics:

@torch.no_grad()
def diagnose_gan(G, D, dataloader, latent_dim):
    metrics = {}
    
    # 1. Gradient health (compute_grad_norm reads .grad from the last backward)
    metrics['d_grad_norm'] = compute_grad_norm(D)
    metrics['g_grad_norm'] = compute_grad_norm(G)
    
    # 2. Mode coverage: inter-batch diversity of generated samples
    fake_samples = []
    for _ in range(10):
        z = torch.randn(100, latent_dim)
        fake_samples.append(G(z))
    metrics['diversity'] = compute_diversity(fake_samples)
    
    # 3. Discriminator confidence
    real_scores, fake_scores = [], []
    for real_batch in dataloader:
        z = torch.randn(real_batch.size(0), latent_dim)
        real_scores.append(D(real_batch).mean())
        fake_scores.append(D(G(z)).mean())
    
    metrics['d_real_acc'] = (torch.stack(real_scores) > 0.5).float().mean().item()
    metrics['d_fake_acc'] = (torch.stack(fake_scores) < 0.5).float().mean().item()
    
    return metrics

Warning signs:

  • D accuracy > 0.99 or < 0.4
  • Gradient norm < 1e-5 or > 100
  • Diversity score < 0.1

Production Training Pipeline

Complete Training Recipe:

class GANTrainer:
    def __init__(self, G, D, config):
        self.G = G
        self.D = D
        
        # Optimizers with different LR
        self.opt_G = Adam(G.parameters(), 
                         lr=config.lr_g, betas=(0.5, 0.999))
        self.opt_D = Adam(D.parameters(),
                         lr=config.lr_d, betas=(0.5, 0.999))
        
        # Learning rate scheduling
        self.scheduler_G = ExponentialLR(self.opt_G, gamma=0.99)
        self.scheduler_D = ExponentialLR(self.opt_D, gamma=0.99)
        
        # Gradient penalty weight
        self.lambda_gp = config.lambda_gp
        
        # Update schedule and latent dimensionality
        self.d_steps = config.d_steps
        self.latent_dim = config.latent_dim
        
        # Label smoothing
        self.real_label = 0.9
        self.fake_label = 0.1
        
    def train_step(self, real_batch):
        batch_size = real_batch.size(0)
        
        # Train Discriminator
        for _ in range(self.d_steps):
            self.opt_D.zero_grad()
            
            # Real samples
            real_validity = self.D(real_batch)
            real_labels = torch.full_like(real_validity, 
                                         self.real_label)
            real_labels += torch.rand_like(real_labels) * 0.1
            
            # Fake samples  
            z = torch.randn(batch_size, self.latent_dim)
            fake = self.G(z).detach()
            fake_validity = self.D(fake)
            fake_labels = torch.full_like(fake_validity,
                                         self.fake_label)
            fake_labels += torch.rand_like(fake_labels) * 0.1
            
            # Gradient penalty
            gp = self.gradient_penalty(real_batch, fake)
            
            # Total loss
            d_loss = (
                F.binary_cross_entropy(real_validity, real_labels) +
                F.binary_cross_entropy(fake_validity, fake_labels) +
                self.lambda_gp * gp
            )
            
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), 5.0)
            self.opt_D.step()
        
        # Train Generator
        self.opt_G.zero_grad()
        z = torch.randn(batch_size, self.latent_dim)
        fake = self.G(z)
        fake_validity = self.D(fake)
        
        # Generator wants D to output 1 for fake
        g_loss = F.binary_cross_entropy(fake_validity, 
                                       torch.ones_like(fake_validity))
        
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.G.parameters(), 5.0)
        self.opt_G.step()
        
        return {'d_loss': d_loss.item(), 'g_loss': g_loss.item()}

Training Configuration:

# config.yaml
training:
  epochs: 200
  batch_size: 64
  
  # Learning rates
  lr_g: 0.0001
  lr_d: 0.0002
  
  # Update frequency
  d_steps: 1  # 5 for WGAN
  g_steps: 1
  
  # Regularization
  lambda_gp: 10  # WGAN-GP
  label_smoothing: 0.1
  
  # Stability
  gradient_clip: 5.0
  spectral_norm: true
  
  # Data augmentation
  augmentation:
    horizontal_flip: 0.5
    color_jitter: 0.1
    
  # Checkpointing
  save_every: 1000
  validate_every: 500
  
  # Early stopping
  patience: 10000
  min_fid: 20.0

Hardware Requirements:

Model | GPU Memory | Training Time | Batch Size
DCGAN 64×64 | 4GB | 12h | 128
StyleGAN 256×256 | 16GB | 3d | 32
BigGAN 128×128 | 32GB | 7d | 256
StyleGAN2 1024×1024 | 48GB | 14d | 8

Guidance:

  • Start with small resolution
  • Use mixed precision training
  • Monitor metrics frequently
  • Save checkpoints regularly

Wasserstein GAN

Earth Mover’s Distance

Wasserstein-1 distance (Earth Mover’s):

\[W(p,q) = \inf_{\gamma \in \Pi(p,q)} \mathbb{E}_{(x,y) \sim \gamma}[||x - y||]\]

where \(\Pi(p,q)\) = set of all joint distributions with marginals p and q

Intuition:

  • Minimum cost to transform distribution p into q
  • Cost = amount of “mass” × distance moved
  • Unlike JS/KL: defined even when supports don’t overlap

Example:

  • \(p = \delta_0\), \(q = \delta_\alpha\) (point masses at \(0\) and \(\alpha\))
  • \(\text{JS}(p,q) = \log 2\) (maximum!) for any \(\alpha \neq 0\)
  • \(W(p,q) = |\alpha|\) (grows smoothly with distance; see the sketch below)
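For 1-D empirical distributions, \(W_1\) is just the mean absolute difference of sorted samples, so the smooth-in-distance behavior is easy to verify (a minimal sketch):

import numpy as np

rng = np.random.default_rng(0)
p = rng.standard_normal(10_000)
q = rng.standard_normal(10_000) + 3.0   # q shifted by alpha = 3

w1 = np.abs(np.sort(p) - np.sort(q)).mean()
print(w1)   # ~3.0: grows with the shift, while JS would saturate at log 2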

Advantages of Wasserstein distance:

  • Provides gradients even without distribution overlap
  • Distance correlates with sample quality
  • Continuous measure prevents mode jumping

Kantorovich-Rubinstein Duality

Dual formulation:

\[W(p,q) = \sup_{||f||_L \leq 1} \mathbb{E}_{x \sim p}[f(x)] - \mathbb{E}_{y \sim q}[f(y)]\]

where \(||f||_L \leq 1\) means f is 1-Lipschitz: \[|f(x_1) - f(x_2)| \leq |x_1 - x_2|\]

Why this helps:

  • Don’t need to find optimal transport plan γ
  • Can parameterize f with neural network
  • Optimization over functions, not distributions

Connection to discriminator:

  • f plays role similar to discriminator
  • But outputs real values, not probabilities
  • Must satisfy Lipschitz constraint

WGAN Objective

WGAN formulation:

\[\max_{D} \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[D(G(\mathbf{z}))]\]

subject to: D is 1-Lipschitz

No more sigmoids:

  • D outputs real values (critic, not discriminator)
  • No saturation problems
  • Wasserstein distance = objective value

Training algorithm:

for iteration in range(num_iterations):
    # Train critic multiple times
    for _ in range(n_critic):  # typically 5
        x_real = sample_batch(data)
        z = sample_noise(batch_size)
        x_fake = G(z)
        
        d_loss = -D(x_real).mean() + D(x_fake).mean()
        update_D(d_loss)
        enforce_lipschitz(D)  # Weight clipping or gradient penalty
    
    # Train generator
    z = sample_noise(batch_size)
    g_loss = -D(G(z)).mean()
    update_G(g_loss)

Why Lipschitz constraint matters:

  • Without it, D can arbitrarily increase outputs
  • Gradients explode → training fails
  • Ensures Wasserstein distance is properly estimated

Weight Clipping Problems

Original WGAN: Enforce Lipschitz via weight clipping

# After each gradient update
for param in D.parameters():
    param.data.clamp_(-0.01, 0.01)

Problems with weight clipping:

  1. Capacity underuse:

    • Most weights pushed to ±c
    • Network becomes less expressive
    • Binary weight distribution
  2. Gradient issues:

    • Vanishing gradients if c too small
    • Exploding gradients if c too large
    • Sensitive to c value
  3. Optimization difficulty:

    • Weights stuck at boundaries
    • Hard to learn complex functions

Weight distribution after clipping:

  • Concentrated at ±c boundaries
  • Lost most information
  • Network essentially linear

Gradient Penalty (WGAN-GP)

Better Lipschitz enforcement:

Instead of clipping, add penalty: \[L = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[D(G(\mathbf{z}))]\] \[+ \lambda \mathbb{E}_{\hat{\mathbf{x}} \sim p_{\hat{\mathbf{x}}}}[(||\nabla_{\hat{\mathbf{x}}} D(\hat{\mathbf{x}})||_2 - 1)^2]\]

Sample \(\hat{\mathbf{x}}\) along interpolations: \[\hat{\mathbf{x}} = \epsilon \mathbf{x}_{\text{real}} + (1-\epsilon) \mathbf{x}_{\text{fake}}\] where \(\epsilon \sim U[0,1]\)

Why ||∇D|| = 1:

  • 1-Lipschitz function has gradient norm ≤ 1
  • Optimal critic has ||∇D|| = 1 almost everywhere
  • Penalty encourages this directly

Implementation:

# Interpolate between real and fake
eps = torch.rand(batch_size, 1, 1, 1)
x_hat = eps * x_real + (1 - eps) * x_fake
x_hat.requires_grad_(True)

# Compute gradient penalty (.sum() yields per-sample input gradients)
d_hat = discriminator(x_hat)
grad = autograd.grad(d_hat.sum(), x_hat,
                     create_graph=True)[0]
grad_norm = grad.view(batch_size, -1).norm(2, dim=1)
gp = lambda_gp * ((grad_norm - 1) ** 2).mean()  # 'lambda' is reserved in Python

Computational cost:

  • 3× backward passes per iteration
  • 25-50% slower than vanilla GAN
  • λ = 10 typical value

WGAN-GP advantages:

  • No capacity underuse
  • Stable training
  • Better sample quality
  • Little hyperparameter tuning (λ=10 works well across tasks)

GAN Architectures

Transposed Convolution Mechanics

Upsampling in Generator Networks:

Standard convolution (stride=1, 3×3 kernel):

  • Input: H × W → Output: H × W
  • Parameters: K² × C_in × C_out

Transposed convolution (stride=2, 3×3 kernel):

  • Input: H × W → Output: 2H × 2W
  • Parameters: K² × C_in × C_out (same!)

Matrix perspective:

  • Conv: y = Wx (W is sparse Toeplitz)
  • TransposeConv: y = W^T x (transpose of conv matrix)
  • Not inverse: W^T W ≠ I

Computational cost:

  • Same FLOPs as equivalent convolution
  • Memory: 4× output due to larger spatial dims

Implementation (PyTorch):

# Upsampling from 4×4 to 8×8
nn.ConvTranspose2d(
    in_channels=512,
    out_channels=256, 
    kernel_size=4,
    stride=2,
    padding=1
)
# Output: [B, 256, 8, 8]
# Parameters: 4×4×512×256 = 2,097,152

Checkerboard artifacts:

  • Overlap when stride < kernel_size
  • Solution: stride=kernel_size or resize+conv
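The resize-then-convolve alternative is a drop-in replacement for the strided transposed convolution above (a sketch; channel sizes follow the earlier example):

import torch.nn as nn

# Nearest-neighbor upsample + stride-1 conv: no kernel overlap,
# hence no checkerboard artifacts
up_block = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(512, 256, kernel_size=3, padding=1),
)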

DCGAN Principles

Deep Convolutional GAN (2015)

  1. All-convolutional nets

    • Replace pooling with strided convolutions
    • Replace fully connected with global pooling
  2. Batch normalization

    • In G: all layers except output
    • In D: all layers except input
    • Stabilizes training dramatically
  3. Activation functions

    • G: ReLU hidden, Tanh output
    • D: LeakyReLU throughout (α=0.2)
  4. No fully connected hidden layers

    • Only convolutions
    • Reduces parameters from 50M to 3M

Computational requirements:

  • Parameters: G: 3.5M, D: 2.8M
  • Memory (batch=128): ~4GB GPU RAM
  • Training: 2-3 days on single GPU for 64×64

Impact: First architecture to reliably generate sharp images at 64×64

Conditional GAN (cGAN)

Control generation with conditions

Standard GAN: G(z) → x

Conditional GAN: G(z, y) → x

where y can be:

  • Class label (one-hot vector)
  • Text embedding
  • Another image
  • Any auxiliary information

Modified objective: \[\min_G \max_D V(D,G) = \mathbb{E}_{x,y}[\log D(x|y)]\] \[+ \mathbb{E}_{z,y}[\log(1-D(G(z|y)|y))]\]

Both G and D see the condition y

Conditioning Implementation

How to inject conditions:

1. Concatenation (simplest)

# Generator: condition appended to the latent code
x = concat([z, y], dim=1)
x = linear(x, 4*4*1024)

# Discriminator: condition tiled to a spatial map, appended channel-wise
x = concat([image, y_spatial], dim=1)  # y_spatial: y broadcast to H×W
x = conv2d(x, 64)

2. Projection Discriminator (better)

# Inner product of embedding and features
h = conv_layers(x)  # → features
y_emb = embed(y)    # → embedding
score = h @ y_emb.T + bias

3. Adaptive Instance Norm (StyleGAN)

  • Learn scale γ(y) and shift β(y) from condition
  • Apply after normalization: AdaIN(x,y) = γ(y)·norm(x) + β(y)
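A minimal AdaIN sketch (gamma and beta of shape (B, C) are assumed to come from a small network on the condition or style vector):

import torch

def adain(x, gamma, beta, eps=1e-5):
    # x: (B, C, H, W). Normalize each feature map, then apply the
    # condition-dependent scale and shift: gamma(y) * norm(x) + beta(y)
    mu = x.mean(dim=(2, 3), keepdim=True)
    std = x.std(dim=(2, 3), keepdim=True) + eps
    return gamma[:, :, None, None] * (x - mu) / std + beta[:, :, None, None]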

Trade-offs:

  • Strong conditioning → less diversity
  • Weak conditioning → ignores condition
  • Need to balance based on application

Progressive GAN

Growing resolution during training

Start: 4×4 → 8×8 → … → 1024×1024

Progressive training schedule:

  1. Train 4×4 GAN to convergence
  2. Add layers for 8×8, blend smoothly
  3. Fade in new layers with α ∈ [0,1]
  4. Continue to higher resolutions

Smooth transition:

# Alpha increases from 0 to 1
low_res = upsample(prev_layer)
high_res = new_conv_layer(prev_layer)
output = (1-alpha)*low_res + alpha*high_res

Benefits:

  • Stable at high resolution
  • 2-6× faster training
  • Progressive refinement of details

StyleGAN Innovations

Style-based generator

1. Mapping network:

  • z → w (8 FC layers)
  • w-space more disentangled than z-space
  • Better semantic control

2. Style injection via AdaIN:

  • Each layer gets style from w
  • Controls features at that scale
  • Coarse layers: pose, shape
  • Fine layers: colors, textures

3. Stochastic variation:

  • Add noise at each layer
  • Controls fine details (hair, freckles)
  • Doesn’t affect overall structure

Perceptual path length:

  • Regularize to make w-space linear
  • Smoother interpolations

CycleGAN: Unpaired Translation

Problem: No paired training data

Want: horses ↔ zebras, summer ↔ winter

CycleGAN solution:

  • Two generators: G: X→Y, F: Y→X
  • Two discriminators: D_X, D_Y
  • Cycle consistency: F(G(x)) ≈ x

Objectives:

Adversarial loss: \[\mathcal{L}_{\text{GAN}}(G, D_Y) = \mathbb{E}_y[\log D_Y(y)] + \mathbb{E}_x[\log(1-D_Y(G(x)))]\]

Cycle consistency loss: \[\mathcal{L}_{\text{cyc}} = \mathbb{E}_x[||F(G(x)) - x||_1] + \mathbb{E}_y[||G(F(y)) - y||_1]\]

Total: \(\mathcal{L} = \mathcal{L}_{\text{GAN}} + \lambda \mathcal{L}_{\text{cyc}}\)
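The cycle term translates directly into code (a sketch; G_xy and G_yx are the two generators, and the adversarial terms are handled separately):

import torch.nn.functional as F

def cycle_consistency_loss(G_xy, G_yx, x, y):
    # F(G(x)) ≈ x and G(F(y)) ≈ y, both under an L1 penalty
    return (F.l1_loss(G_yx(G_xy(x)), x) +
            F.l1_loss(G_xy(G_yx(y)), y))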

CycleGAN Results and Issues

What works well:

  • Style transfer (photo ↔ painting)
  • Season transfer (summer ↔ winter)
  • Object transfiguration (horse ↔ zebra)

Common failure modes:

  1. Semantic changes:

    • Dog → cat might change pose
    • Loses semantic correspondence
  2. Mode collapse in cycles:

    • All inputs map to same output
    • Then map back correctly
  3. Color/texture bias:

    • Focuses on easy changes
    • Ignores structural changes

Computational cost:

  • 4 networks to train
  • 2× memory of standard GAN

Other Important Architectures

BigGAN (2018)

  • Class-conditional at scale
  • Batch size 2048 (vs typical 64)
  • Spectral normalization + self-attention
  • Trade quality for diversity with “truncation trick”

StyleGAN2/3 (2019/2021)

  • Fixed artifacts in StyleGAN
  • Weight demodulation instead of AdaIN
  • Path length regularization
  • StyleGAN3: alias-free (rotation equivariant)

Diffusion-GAN Hybrids

  • Use GAN discriminator with diffusion
  • Faster sampling than pure diffusion
  • Better mode coverage than pure GAN

Computational requirements:

  • BigGAN: 128 TPUs × 48 hours = 6,144 TPU-hours
  • StyleGAN2: 8 V100s × 9 days = 1,728 GPU-hours
  • Memory: 32-64GB for large batch training
  • Parameters: BigGAN: 158M, StyleGAN2: 30M

Pix2Pix: Conditional Image Translation

Architecture (Isola et al. 2017):

Generator: U-Net with skip connections

  • Encoder: Conv → BN → LeakyReLU
  • Decoder: TransposeConv → BN → ReLU
  • Skip connections preserve detail

Discriminator: PatchGAN (70×70 receptive field)

  • Classifies overlapping patches as real/fake
  • Parameters: 2.7M (vs 41M for full image)

Objective: \[\mathcal{L} = \mathcal{L}_{cGAN}(G,D) + \lambda \mathcal{L}_{L1}(G)\]

where: \[\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[||y - G(x,z)||_1]\]

λ = 100 typical (balances adversarial vs reconstruction)

Performance (256×256 images):

Task | PSNR | SSIM | FID | Time/img
Maps→Aerial | 21.2 | 0.42 | 45.3 | 22ms
Edges→Photo | 18.8 | 0.38 | 62.1 | 22ms
Day→Night | 19.5 | 0.51 | 38.9 | 22ms

Memory requirements:

  • Training: 12GB (batch=1)
  • Inference: 2GB
  • Dataset: 400 image pairs minimum

GAN Evaluation

Inception Score Limitations

Inception Score (IS):

\[\text{IS}(G) = \exp\left(\mathbb{E}_{\mathbf{x} \sim p_g}\left[\text{KL}(p(y|\mathbf{x}) || p(y))\right]\right)\]

where:

  • \(p(y|\mathbf{x})\) = Inception network’s class predictions
  • \(p(y)\) = marginal class distribution

What IS measures:

  • High confidence predictions (quality)
  • Diverse predictions across samples (diversity)
  • Range: 1 (worst) to #classes (best)
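Given softmax outputs from the Inception network, IS is a few lines (a sketch; preds is an N × #classes array of class probabilities):

import numpy as np

def inception_score(preds, eps=1e-12):
    # exp( E_x[ KL( p(y|x) || p(y) ) ] )
    p_y = preds.mean(axis=0)                                   # marginal p(y)
    kl = (preds * (np.log(preds + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))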

Inception Score problems:

  1. Intra-class mode collapse not penalized

    • One sharp sample per class → near-maximal confidence and a near-uniform marginal
    • Diversity within each class is invisible to IS
  2. ImageNet-specific

    • Meaningless for other domains
    • Cannot evaluate unconditional generation

IS is a weak proxy for quality:

  • Collapsed generators can still score highly
  • Optimizing IS directly overfits to the Inception network

Fréchet Inception Distance

FID measures distribution similarity:

\[\text{FID} = ||\boldsymbol{\mu}_r - \boldsymbol{\mu}_g||^2 + \text{Tr}(\boldsymbol{\Sigma}_r + \boldsymbol{\Sigma}_g - 2\sqrt{\boldsymbol{\Sigma}_r \boldsymbol{\Sigma}_g})\]

where:

  • \(\boldsymbol{\mu}_r, \boldsymbol{\Sigma}_r\) = real data statistics
  • \(\boldsymbol{\mu}_g, \boldsymbol{\Sigma}_g\) = generated data statistics
  • Computed from Inception-v3 pool3 layer (2048-d)

Assumptions:

  • Features are Gaussian distributed
  • Inception features capture perceptual similarity

Requirements:

  • Minimum 10k samples (50k recommended)
  • Same preprocessing as Inception training
  • Lower is better (0 = identical distributions)

FID correlates with human judgment:

  • Captures both quality and diversity
  • Sensitive to mode collapse
  • More robust than IS
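Given the feature statistics, the FID formula is short; the matrix square root is the only subtle step (a sketch using scipy):

import numpy as np
from scipy import linalg

def fid(mu_r, sigma_r, mu_g, sigma_g):
    # ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2})
    covmean = linalg.sqrtm(sigma_r @ sigma_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real    # discard tiny imaginary parts from numerics
    return float(((mu_r - mu_g) ** 2).sum()
                 + np.trace(sigma_r + sigma_g - 2 * covmean))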

Sample Size Requirements

FID variance depends on sample size:

With N samples:

  • Variance ∝ 1/N
  • Bias decreases with N
  • Need N > 10k for stable estimates

Recommended sample sizes:

  • Quick evaluation: 10k
  • Paper results: 50k
  • Final comparison: 100k+

Computational cost:

  • Extract features: ~1 min per 10k images
  • Compute statistics: negligible
  • Total: Linear in sample size

Common mistakes:

  • Using different N for real/fake
  • Different preprocessing
  • Wrong Inception model version

Training Tricks That Matter

Spectral Normalization

# Normalize weights by largest singular value
W_sn = W / sigma_max(W)
  • Enforces Lipschitz constraint
  • More stable than batch norm for \(D\)
  • 5-10% FID improvement

Self-Attention Layers

# At 32×32 resolution
attention = softmax(Q @ K.T / sqrt(d)) @ V
  • Captures long-range dependencies
  • Add to both G and D
  • 15-20% FID improvement

Moving Average Generator

# Exponential moving average
G_ema = beta * G_ema + (1-beta) * G
  • β = 0.999 typical
  • Use G_ema for generation
  • Reduces variance in outputs

Other useful tricks:

  • R1/R2 regularization for \(D\)
  • Differentiable augmentation
  • Truncation trick (trade quality/diversity)
  • Learning rate scheduling

GANs Beyond Computer Vision

Non-Vision Applications

Text Generation (SeqGAN, LeakGAN):

  • Discrete sequence handling via reinforcement learning
  • BLEU scores: 0.85 on short text (< 50 tokens)
  • Challenge: Non-differentiable sampling

Audio Synthesis (WaveGAN, MelGAN):

  • WaveGAN: 16kHz audio, 1 sec clips
  • MelGAN: 22kHz, real-time synthesis (0.03 RTF)
  • Memory: 200MB model, 4GB training

Molecular Generation (MolGAN, ChemGAN):

  • Valid molecules: 95% (ZINC dataset)
  • Novel molecules: 82% not in training
  • QED scores: 0.72 average drug-likeness

Tabular Data (CTGAN, TGAN):

# Handling mixed data types
continuous_cols = ['age', 'income']
categorical_cols = ['education', 'occupation']

# Mode-specific normalization
# Gaussian mixture for continuous
# One-hot + embedding for categorical

Performance comparison:

Domain | GAN Type | Metric | Score
Text | SeqGAN | BLEU-4 | 0.85
Audio | MelGAN | MOS | 4.2/5
Molecules | MolGAN | Validity | 95%
Tabular | CTGAN | F1 | 0.89

Implementation Challenges Beyond Images

Discrete Sequences (Text/Code):

Problem: Sampling is non-differentiable

# Can't backprop through argmax
tokens = torch.argmax(logits, dim=-1)  # Breaks gradients

Solutions:

  1. REINFORCE: Treat as RL problem

    • Reward: Discriminator score
    • Variance reduction critical
  2. Gumbel-Softmax: Continuous relaxation

    # Temperature-controlled soft sampling
    soft_tokens = F.gumbel_softmax(logits, tau=0.5)
  3. Continuous embeddings: Skip discrete entirely

Time Series (Audio/Finance):

Challenges:

  • Long-range dependencies (10k+ timesteps)
  • Multiple timescales
  • Causality constraints

Architecture modifications:

# Dilated convolutions grow the receptive field exponentially
layers = [
    nn.Conv1d(ch, ch, kernel_size=2, dilation=2**i)
    for i in range(10)
]
# Receptive field: 2^10 = 1024 timesteps

Memory requirements:

  • Audio (1 min): 2.6M samples @ 44.1kHz
  • Memory: O(sequence_length × batch)
  • Solution: Sliding window generation

Success Metrics by Domain

Domain-Specific Evaluation:

Images:

  • FID, IS, LPIPS (perceptual similarity)
  • Human evaluation still gold standard

Text:

  • BLEU, ROUGE (n-gram overlap)
  • Perplexity from pretrained LM
  • Semantic similarity (BERT embeddings)

Audio:

  • MOS (Mean Opinion Score)
  • PESQ (Perceptual Evaluation)
  • Mel-cepstral distortion

Molecules:

  • Validity (parseable SMILES)
  • Uniqueness, Novelty
  • Drug-likeness (QED score)
  • Synthesizability (SA score)

Variational Autoencoders

Latent Variable Models

Generative Model Setup

Joint distribution: \[p(\mathbf{x}, \mathbf{z}) = p(\mathbf{x}|\mathbf{z})p(\mathbf{z})\]

Components:

  • Prior: \(p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})\) (standard Gaussian)
  • Decoder: \(p(\mathbf{x}|\mathbf{z})\) (neural network with parameters θ)
  • Latent dim: typically 10-500 (vs data dim 784-150K)

What we want: \[p(\mathbf{x}) = \int p(\mathbf{x}|\mathbf{z})p(\mathbf{z})d\mathbf{z}\]

Problem: Integral intractable for neural network decoder

Contrast with EBMs:

  • EBMs: \(p(\mathbf{x}) = \frac{1}{Z}e^{-E(\mathbf{x})}\) where \(Z = \int e^{-E(\mathbf{x}')}d\mathbf{x}'\)
  • VAEs: Marginalize latent variables instead of normalizing energies
  • Both face intractable integrals, but VAEs solve via variational inference rather than MCMC

Need p(x) for: Maximum likelihood training, sampling, model comparison

Inference Problem

Posterior intractability:

\[p(\mathbf{z}|\mathbf{x}) = \frac{p(\mathbf{x}|\mathbf{z})p(\mathbf{z})}{p(\mathbf{x})} = \frac{p(\mathbf{x}|\mathbf{z})p(\mathbf{z})}{\int p(\mathbf{x}|\mathbf{z}')p(\mathbf{z}')d\mathbf{z}'}\]

Intractability both ways:

  • Forward: \(p(\mathbf{x})\) requires integrating over \(\mathbf{z}\)
  • Inverse: \(p(\mathbf{z}|\mathbf{x})\) requires \(p(\mathbf{x})\) in denominator

Standard solution: Variational Inference

  • Approximate p(z|x) with tractable q(z|x)
  • Minimize KL(q||p) → maximize ELBO
  • Choose q family: Gaussian with diagonal covariance

Variational Inference Principle

Goal: Find q(z|x) ≈ p(z|x)

Direct KL minimization: \[\text{KL}(q_{\phi}(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}|\mathbf{x})) = \mathbb{E}_{q_{\phi}}\left[\log \frac{q_{\phi}(\mathbf{z}|\mathbf{x})}{p(\mathbf{z}|\mathbf{x})}\right]\]

Problem: Requires p(z|x) which needs p(x)!

Solution: Rewrite using Bayes rule: \[\log p(\mathbf{x}) = \text{KL}(q_{\phi}(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}|\mathbf{x})) + \mathcal{L}(\theta, \phi; \mathbf{x})\]

where Evidence Lower BOund (ELBO): \[\mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})] - \text{KL}(q_{\phi}(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))\]

Why ELBO works:

  • KL ≥ 0 → ELBO ≤ log p(x)

  • Maximizing ELBO:

    1. Tightens bound on log p(x)
    2. Minimizes KL(q||p)
  • Tractable: only needs p(z) and p(x|z)

ELBO Derivation

VAE Architecture and ELBO Connection

Encoder Network \(q_{\phi}(\mathbf{z}|\mathbf{x})\):

  • Input: data \(\mathbf{x}\)
  • Output: parameters \(\boldsymbol{\mu}_{\phi}(\mathbf{x})\), \(\boldsymbol{\sigma}_{\phi}(\mathbf{x})\)
  • Distribution: \(q_{\phi}(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}_{\phi}(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}_{\phi}^2(\mathbf{x})))\)

Reparameterization sampling: \[\mathbf{z} = \boldsymbol{\mu}_{\phi}(\mathbf{x}) + \boldsymbol{\sigma}_{\phi}(\mathbf{x}) \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})\]

Decoder Network \(p_{\theta}(\mathbf{x}|\mathbf{z})\):

  • Input: latent code \(\mathbf{z}\)
  • Output: reconstruction parameters
  • For images: \(p_{\theta}(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}_{\theta}(\mathbf{z}), \sigma^2 \mathbf{I})\)

ELBO Loss Computation: \[\mathcal{L} = \underbrace{\log p_{\theta}(\mathbf{x}|\mathbf{z})}_{\text{Decoder output}} - \underbrace{\text{KL}(q_{\phi}(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))}_{\text{Encoder regularity}}\]

Connections:

  • Mathematical \(q_{\phi}(\mathbf{z}|\mathbf{x})\) → neural encoder outputting \(\boldsymbol{\mu}, \boldsymbol{\sigma}\) parameters
  • Mathematical \(p_{\theta}(\mathbf{x}|\mathbf{z})\) → neural decoder computing reconstruction likelihood
  • ELBO terms → concrete loss functions with network gradients
  • Sampling \(\mathbf{z}\) → reparameterization trick for gradient flow

Evidence Lower Bound Derivation

Start with log marginal: \[\log p(\mathbf{x}) = \log \int p(\mathbf{x}, \mathbf{z}) d\mathbf{z}\]

Introduce q(z|x) via importance sampling: \[\log p(\mathbf{x}) = \log \int \frac{p(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z}|\mathbf{x})} q_{\phi}(\mathbf{z}|\mathbf{x}) d\mathbf{z}\]

Rewrite as expectation: \[\log p(\mathbf{x}) = \log \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\frac{p(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z}|\mathbf{x})}\right]\]

Apply Jensen’s inequality (log is concave): \[\log \mathbb{E}[f(\mathbf{z})] \geq \mathbb{E}[\log f(\mathbf{z})]\]

Therefore: \[\log p(\mathbf{x}) \geq \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}\left[\log \frac{p(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z}|\mathbf{x})}\right]\]

Expand joint p(x,z) = p(x|z)p(z): \[\mathcal{L} = \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})] - \text{KL}(q_{\phi}(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))\]

Result: ELBO = Reconstruction - Regularization

  • Reconstruction: \(\mathbb{E}[\log p(\mathbf{x}|\mathbf{z})]\) pushes \(q\) to encode \(\mathbf{x}\) well
  • Regularization: \(\text{KL}(q||p)\) keeps \(q(\mathbf{z}|\mathbf{x})\) close to prior \(p(\mathbf{z})\)

How this solves the EBM problem:

  • EBMs needed samples from \(p(\mathbf{x})\) via expensive MCMC
  • VAEs only need samples from \(q(\mathbf{z}|\mathbf{x})\), which we design to be simple (Gaussian)
  • Trade exact likelihood for tractable lower bound

ELBO Interpretation

Two forms of ELBO:

Form 1: Reconstruction - KL \[\mathcal{L} = \underbrace{\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))}_{\text{Regularization}}\]

Form 2: Negative free energy \[\mathcal{L} = \mathbb{E}_{q(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}, \mathbf{z})] + H[q(\mathbf{z}|\mathbf{x})]\]

where \(H\) is entropy of \(q\)

Information theory view:

  • Reconstruction: bits to encode \(\mathbf{x}\) given \(\mathbf{z}\)
  • KL term: extra bits to encode \(\mathbf{z}\) using \(q\) vs \(p\)
  • Total: compression bound on \(\mathbf{x}\)

Practical implications:

  • Small latent dim → high reconstruction error
  • Large latent dim → high KL penalty
  • Sweet spot: captures information with minimal bits

KL Divergence for Gaussians

Encoder outputs Gaussian q(z|x): \[q_{\phi}(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}^2(\mathbf{x})))\]

Prior is standard Gaussian: \[p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})\]

Closed form KL: \[\text{KL}(q || p) = \frac{1}{2}\sum_{j=1}^J \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)\]

Derivation: For Gaussians \(\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)\) and \(\mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)\): \[\text{KL} = \frac{1}{2}\left[\text{tr}(\boldsymbol{\Sigma}_2^{-1}\boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)^T\boldsymbol{\Sigma}_2^{-1}(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1) - k + \log\frac{|\boldsymbol{\Sigma}_2|}{|\boldsymbol{\Sigma}_1|}\right]\]

For our case: \(\boldsymbol{\mu}_2 = 0\), \(\boldsymbol{\Sigma}_2 = \mathbf{I}\), diagonal \(\boldsymbol{\Sigma}_1\)

What each term penalizes:

  • \(\mu_j^2\): Mean deviation from 0
  • \(\sigma_j^2\): Variance larger than 1
  • \(-\log \sigma_j^2\): Variance smaller than 1
  • Minimum at μ=0, σ=1 (matches prior)
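
As a quick sanity check, the closed form can be verified against torch.distributions; a minimal sketch with arbitrary illustrative values for mu and log_var:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.5, -1.0])       # arbitrary encoder outputs
log_var = torch.tensor([0.2, -0.3])
std = (0.5 * log_var).exp()

# Closed form from above: 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1)
kl_closed = 0.5 * (mu.pow(2) + log_var.exp() - log_var - 1).sum()

# Reference value from torch.distributions (broadcasts against N(0, 1))
kl_ref = kl_divergence(Normal(mu, std), Normal(0.0, 1.0)).sum()

assert torch.allclose(kl_closed, kl_ref)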

Reparameterization Trick

Gradient Problem Through Sampling

Need gradient of expectation: \[\nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x})}[f(\mathbf{z})] = \nabla_{\phi} \int q_{\phi}(\mathbf{z}|\mathbf{x}) f(\mathbf{z}) d\mathbf{z}\]

We need this gradient with respect to the encoder parameters φ that define q(z|x):

\(\nabla_{\phi} \mathbb{E}_{\mathbf{z} \sim q_{\phi}(\mathbf{z}|\mathbf{x})}[\log p_{\theta}(\mathbf{x}|\mathbf{z})]\)

Problem: The expectation is taken over a distribution that itself depends on φ, so the gradient cannot be pushed inside the expectation directly, and naive estimators have high variance.

Why naive sampling fails:

  • Backprop breaks at the sampling step z ~ q_φ(z|x)
  • The sampling operation itself is non-differentiable
  • Monte Carlo workarounds (score-function estimators) suffer high variance

Two approaches:

1. Score function estimator (REINFORCE): \[\nabla_{\phi} \mathbb{E}_{q_{\phi}}[f] = \mathbb{E}_{q_{\phi}}[f(\mathbf{z}) \nabla_{\phi} \log q_{\phi}(\mathbf{z}|\mathbf{x})]\]

Problem: Variance scales as \(O(e^D)\) with dimension D

2. Reparameterization trick: \[\mathbf{z} = \boldsymbol{\mu}_{\phi}(\mathbf{x}) + \boldsymbol{\sigma}_{\phi}(\mathbf{x}) \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})\]

Now: \(\nabla_{\phi} \mathbb{E}_{\boldsymbol{\epsilon}}[f(\mathbf{z})] = \mathbb{E}_{\boldsymbol{\epsilon}}[\nabla_{\phi} f(\mathbf{z})]\)

Why reparameterization works:

  • Moves stochasticity outside the network
  • Gradient flows through deterministic path
  • Variance reduced by factor of ~1000

Variance Analysis

Gradient variance comparison:

Score function gradient: \[\text{Var}[\nabla_{\phi}^{\text{SF}}] \approx \text{Var}[f(\mathbf{z})] \cdot \text{Var}[\nabla_{\phi} \log q_{\phi}]\]

  • Product of two variances
  • Scales exponentially with dimension
  • Requires 1000s of samples per gradient

Reparameterized gradient: \[\text{Var}[\nabla_{\phi}^{\text{Reparam}}] \approx \text{Var}_{\boldsymbol{\epsilon}}[\nabla_{\mathbf{z}} f(\mathbf{z})] \cdot ||\nabla_{\phi} \mathbf{z}||^2\]

  • Only gradient variance matters
  • Scales linearly with dimension
  • Works with single sample

Empirical comparison (MNIST VAE):

  • Score function: Var ≈ 10³ after 1K iterations
  • Reparameterization: Var ≈ 1 after 100 iterations

Practical impact:

  • Score function: unusable for \(D > 20\)
  • Reparameterization: works for \(D = 1000+\)
  • Enables deep latent variable models

Implementation Details

Forward pass:

def encode(x):
    # encoder_network, fc_mu, fc_logvar are nn.Modules defined elsewhere
    h = encoder_network(x)
    mu = fc_mu(h)
    log_var = fc_logvar(h)  # log variance for stability
    return mu, log_var

def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)  # all randomness lives in eps
    z = mu + eps * std           # deterministic function of mu, std
    return z

def forward(x):
    mu, log_var = encode(x)
    z = reparameterize(mu, log_var)
    x_recon = decode(z)
    return x_recon, mu, log_var

Numerical stability tricks:

  • Output log σ² instead of σ² (prevents negative variance)
  • Clip log variance: log_var = torch.clamp(log_var, -10, 10)
  • Use log-sum-exp for stable likelihood computation

Gradient computation:

  • Automatic differentiation handles everything
  • Single sample per datapoint sufficient
  • Batch size 64-256 typical (vs 1-8 for score function)
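
Putting the pieces together, a minimal training step might look like the sketch below; it assumes the forward() above, the kl_divergence helper from the next slide, and a decoder that returns logits:

import torch
import torch.nn.functional as F

def train_step(model, x, optimizer):
    optimizer.zero_grad()
    x_logits, mu, log_var = model(x)   # forward pass as above
    recon = F.binary_cross_entropy_with_logits(
        x_logits, x, reduction='sum')
    kl = kl_divergence(mu, log_var).sum()  # analytical KL (next slide)
    loss = recon + kl                      # = negative ELBO
    loss.backward()
    optimizer.step()
    return loss.item()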

Numerical Stability in Practice

Stable ELBO computation:

1. KL divergence for Gaussians:

def kl_divergence(mu, log_var):
    # Analytical KL for N(mu, sigma) || N(0, I)
    # Avoid computing sigma directly
    kl = -0.5 * torch.sum(
        1 + log_var - mu.pow(2) - log_var.exp(),
        dim=1
    )
    return kl

2. Reconstruction loss (Bernoulli):

def stable_bce_loss(x_logits, x_true):
    # Use logits directly, avoid sigmoid-then-log
    # max(-x, 0) keeps both exponentials <= 1 for any logit sign
    max_val = torch.clamp(-x_logits, min=0)
    loss = x_logits - x_logits * x_true + max_val \
           + torch.log(torch.exp(-max_val) \
           + torch.exp(-x_logits - max_val))
    return loss.sum(dim=1)  # same as F.binary_cross_entropy_with_logits

3. Log-sum-exp for marginal likelihood:

def log_sum_exp(x, dim=0):
    max_x = torch.max(x, dim=dim, keepdim=True)[0]
    return max_x + torch.log(
        torch.sum(torch.exp(x - max_x), dim=dim)
    )
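
For example, a K-sample importance-weighted estimate of log p(x) (as in IWAE) reduces to a log-sum-exp over the log-weights; a hypothetical usage, where log_w and K are assumed, not from the slides:

import math

# log_w[k, i] = log p(x_i|z_k) + log p(z_k) - log q(z_k|x_i), z_k ~ q(z|x_i)
# log_w has shape [K, batch]
log_px = log_sum_exp(log_w, dim=0) - math.log(K)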

Common pitfalls and fixes:

  • Variance explosion: Clamp log_var to [-10, 10]
  • BCE with logits: Never compute sigmoid then log
  • Small KL: Use free bits or minimum KL threshold
  • Gradient explosion: Clip gradients by norm

VAE Architecture

Encoder-Decoder Architecture

Encoder network q(z|x):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, 
                 latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        
    def forward(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc2_mu(h)
        log_var = self.fc2_logvar(h)
        # Clamp for numerical stability
        log_var = torch.clamp(log_var, min=-10, max=10)
        return mu, log_var

Decoder network p(x|z):

class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400,
                 output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, z):
        h = F.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))  # binary data; or return logits (next slide)
        return x_recon

Computational advantage over EBMs:

  • Training: 1 forward + 1 backward pass per sample
  • EBMs: 1000+ forward passes for MCMC sampling
  • 100-1000× faster per gradient step

Design choices:

  • Hidden dimension: 2-5× latent dimension
  • Latent dimension: 10-50 for MNIST, 100-500 for images
  • Activation: ReLU in hidden, sigmoid/tanh for output
  • Weight initialization: Xavier/He normal
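
A minimal wrapper tying the two networks together (a sketch, using the Encoder and Decoder classes above):

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z = mu + torch.randn_like(std) * std  # reparameterization
        return self.decoder(z), mu, log_var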

Output Distributions

Binary data (MNIST): \[p(\mathbf{x}|\mathbf{z}) = \prod_{i=1}^D \text{Bernoulli}(x_i | p_i)\]

Decoder outputs logits, loss = binary cross-entropy:

x_logits = decoder(z)  # No sigmoid
recon_loss = F.binary_cross_entropy_with_logits(
    x_logits, x, reduction='sum')

Continuous data (natural images): \[p(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}_{\theta}(\mathbf{z}), \sigma^2 \mathbf{I})\]

Decoder outputs mean, fixed variance:

x_mean = decoder(z)
# Gaussian with fixed unit variance: NLL reduces to MSE (up to constants)
recon_loss = F.mse_loss(x_mean, x, reduction='sum')
# Or learned variance (F.gaussian_nll_loss expects variance, not log-variance)
log_var = decoder_logvar(z)
recon_loss = F.gaussian_nll_loss(x_mean, x, log_var.exp(), reduction='sum')

Impact on reconstruction:

  • Bernoulli: sharp but sometimes noisy
  • Gaussian (fixed σ²): blurry
  • Gaussian (learned σ²): can ignore details

Recommendation:

  • Binary/categorical: Bernoulli/Categorical
  • Continuous: Gaussian with fixed small σ²
  • Avoid learned variance unless necessary

Training Dynamics and Posterior Collapse

Posterior collapse phenomenon:

VAE ignores latent code: q(z|x) ≈ p(z) = N(0,I)

  • KL term → 0 (good?)
  • But z contains no information about x
  • Decoder becomes unconditional generator

Why it happens:

  1. Early training: Decoder can’t use z effectively
  2. Encoder minimizes KL by matching prior
  3. Local optimum: decoder ignores z, encoder outputs prior

Solutions:

1. KL annealing: Start with β=0, increase to 1

beta = min(1.0, epoch / warmup_epochs)
loss = recon_loss + beta * kl_loss

2. Free bits: Minimum KL per dimension

# Floor each dimension's KL at free_bits, then sum
kl_loss = torch.clamp(kl_per_dim,
                      min=free_bits).sum(dim=1)

3. Decoder weakening: Dropout, smaller network

Monitoring collapse:

  • Track KL per dimension
  • Measure mutual information I(x;z)
  • Check reconstruction when sampling z ~ N(0,I)
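
A quick diagnostic along these lines (a sketch): inspect the batch-averaged KL per latent dimension; dimensions with KL near zero carry no information about x.

# mu, log_var: [batch, latent_dim] from the encoder
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
print(kl_per_dim.mean(dim=0))  # near-zero entries = collapsed dimensions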

VAE Variants

β-VAE for Disentanglement

Modified objective with β > 1: \[\mathcal{L}_{\beta} = \mathbb{E}_{q(\mathbf{z}|\mathbf{x})}[\log p(\mathbf{x}|\mathbf{z})] - \beta \cdot \text{KL}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))\]

What β controls:

  • β = 1: Standard VAE
  • β > 1: Stronger independence pressure
  • β >> 1: Forces factorized representation
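
In code the change is a single line relative to the standard VAE loss (a sketch; recon_loss and kl_loss as defined earlier):

loss = recon_loss + beta * kl_loss  # beta = 1 recovers the standard VAE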

Disentanglement metrics:

  • MIG (Mutual Information Gap): 0.15 → 0.45 with β=4
  • SAP (Separated Attribute Predictability): 0.20 → 0.60
  • DCI (Disentanglement, Completeness, Informativeness)

Trade-off: Reconstruction quality vs disentanglement

  • β=1: Reconstruction = -85, MIG = 0.15
  • β=4: Reconstruction = -95, MIG = 0.45
  • β=10: Reconstruction = -110, MIG = 0.50 (diminishing returns)

Applications:

  • Interpretable representations
  • Few-shot learning (better transfer)
  • Controllable generation

VQ-VAE: Discrete Latent Codes

Vector Quantization VAE:

Replace continuous z with discrete codes from codebook

Architecture:

  1. Encoder: \(\mathbf{z}_e = f_{enc}(\mathbf{x})\) (continuous)
  2. Quantization: \(\mathbf{z}_q = \text{nearest}(\mathbf{z}_e, \text{codebook})\)
  3. Decoder: \(\hat{\mathbf{x}} = f_{dec}(\mathbf{z}_q)\)

Codebook learning:

# K vectors of dimension D
codebook = nn.Embedding(num_embeddings=512, embedding_dim=64)

def quantize(z_e):
    # Find nearest codebook vector
    distances = torch.cdist(z_e, codebook.weight)
    indices = distances.argmin(dim=-1)
    z_q = codebook(indices)
    
    # Straight-through estimator
    z_q = z_e + (z_q - z_e).detach()
    return z_q, indices

Loss function: \[\mathcal{L} = \log p(\mathbf{x}|\mathbf{z}_q) + ||\text{sg}[\mathbf{z}_e] - \mathbf{e}||^2 + \beta ||\mathbf{z}_e - \text{sg}[\mathbf{e}]||^2\]

where sg = stop gradient, \(\mathbf{e}\) = codebook vectors
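
In code, the stop-gradients become .detach() calls; a sketch, where z_e is the raw encoder output, z_q the quantized vectors before the straight-through step, and beta the commitment weight:

recon_loss = F.mse_loss(x_recon, x)              # ~ -log p(x|z_q) for Gaussian p
codebook_loss = F.mse_loss(z_q, z_e.detach())    # ||sg[z_e] - e||^2
commitment_loss = F.mse_loss(z_e, z_q.detach())  # ||z_e - sg[e]||^2
loss = recon_loss + codebook_loss + beta * commitment_loss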

Advantages over continuous VAE:

  • True discrete representation (better for sequence modeling)
  • No posterior collapse
  • Can use powerful autoregressive priors
  • Compression rates: 8× → 512×

Hierarchical VAEs

Ladder VAE structure:

Multiple stochastic layers: \[p(\mathbf{x}, \mathbf{z}_1, ..., \mathbf{z}_L) = p(\mathbf{x}|\mathbf{z}_1)p(\mathbf{z}_1|\mathbf{z}_2)...p(\mathbf{z}_L)\]

Benefits:

  • Learn representations at multiple scales
  • Better gradient flow
  • Richer posteriors through lateral connections

Implementation approach:

class HierarchicalVAE(nn.Module):
    def encode(self, x):
        # Bottom-up pass
        h1 = self.enc1(x)
        z1_mu, z1_logvar = self.z1_params(h1)
        
        h2 = self.enc2(h1)
        z2_mu, z2_logvar = self.z2_params(h2)
        
        return [(z1_mu, z1_logvar), 
                (z2_mu, z2_logvar)]
    
    def decode(self, z_list):
        # Top-down generation
        h = self.dec2(z_list[1])
        h = self.dec1(torch.cat([h, z_list[0]], dim=1))
        return self.output(h)

Performance gains:

  • MNIST: -86.4 → -82.1 NLL
  • CIFAR-10: 3.51 → 3.28 bits/dim
  • Better samples without blur

Model Comparison

VAE vs GAN Trade-offs

Quantitative comparison:

| Metric | VAE | GAN |
|---|---|---|
| Training stability | Stable | Unstable |
| Mode coverage | Good | Poor (collapse) |
| Sample quality | Blurry | Sharp |
| Likelihood | Tractable bound | None |
| Latent inference | q(z\|x) | Requires BiGAN |
| Training time (CIFAR-10, V100) | 12 hours | 24-48 hours |
| Hyperparameter sensitivity | Low | High |

Computational requirements (same architecture):

  • VAE: 1× forward + 1× backward per step
  • GAN: 2× forward + 2× backward (D and G)
  • Memory: VAE ≈ 0.5× GAN

Use VAE when: need likelihood, stable training, or latent inference.

Use GAN when: need sharp samples and mode coverage is less critical.

Empirical results (CIFAR-10, single V100):

  • VAE: FID=52, IS=5.1, training=stable
  • GAN: FID=18, IS=7.8, training=unstable
  • VAE-GAN: FID=28, IS=6.5, training=semi-stable

Summary: Generative Model Landscape

  1. EBMs (2000s): Explicit density, intractable

    • Computational cost: Prohibitive
    • Led to: Score matching, contrastive methods
  2. VAEs (2014): Tractable lower bound

    • Solved EBM’s intractability via variational inference
    • Stable training, fast inference (no MCMC)
    • Limitation: Sample quality due to Gaussian assumptions
  3. GANs (2014): Implicit generation

    • High quality samples
    • Limitation: Training instability
  4. Modern (2020s): Best of both worlds

    • Diffusion models: Quality + likelihood
    • VAE-GAN hybrids
    • Discrete representations (VQ-VAE → DALL-E)

2024 methods:

  • Images: Diffusion models (DALL-E 2, Imagen)
  • Text: Autoregressive (GPT)
  • Audio: WaveNet, Diffusion
  • Multi-modal: CLIP + Diffusion

Open problems:

  • Efficiency: Faster sampling, lower memory
  • Control: Better conditioning, editing
  • Theory: Understanding why diffusion works
  • Scale: Trillion parameter generative models