Conditional GANs in TensorFlow: Generating Targeted Synthetic Data

Conditional Generative Adversarial Networks (cGANs) are a variant of GANs that generate synthetic data conditioned on specific inputs, such as class labels or attributes. This enables targeted data generation, making cGANs useful for applications like image synthesis, data augmentation, and style transfer. In TensorFlow, the Keras API makes it straightforward to build cGANs by feeding conditional inputs into both the generator and the discriminator. This guide explains how cGANs work and walks through a practical TensorFlow implementation that generates handwritten digits conditioned on class labels using the MNIST dataset. It covers data preprocessing, model design, training, and advanced techniques.

Introduction to Conditional GANs

Traditional GANs generate data from random noise, producing diverse but uncontrolled outputs. cGANs extend this by conditioning the generator and discriminator on additional information, such as class labels, allowing the generation of specific data types (e.g., a digit “7” or a dog image). Introduced by Mirza and Osindero in 2014, cGANs modify the GAN framework by incorporating conditional inputs, enabling applications like labeled image generation or text-to-image synthesis.

In TensorFlow, cGANs are implemented using Keras layers like Concatenate to merge conditional inputs with noise or image data. We’ll build a conditional DCGAN (a Deep Convolutional GAN with label conditioning) to generate 28x28 grayscale MNIST digits conditioned on digit labels (0–9). The MNIST dataset, containing 60,000 training and 10,000 test images, is ideal for this task. This guide assumes familiarity with GANs; for a primer, refer to Generative Adversarial Networks.

Mechanics of Conditional GANs

How Conditional GANs Work

A cGAN consists of two networks:

  • Generator (G): Takes a noise vector \( z \) and a conditional input (e.g., a class label \( y \)) to produce synthetic data \( G(z, y) \). It aims to generate data that fools the discriminator.
  • Discriminator (D): Takes data (real or fake) and the same conditional input \( y \) to predict whether the data is real (1) or fake (0). It evaluates if the data matches the condition.
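
As a rough illustration of what conditioning means at the input level, the NumPy sketch below builds a generator input by concatenating a noise vector with a one-hot label. The helper names `one_hot` and `make_generator_input` are hypothetical, and the dimensions (100 noise values, 10 classes) are assumptions chosen to match the MNIST example later in this guide.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Convert integer labels to one-hot rows."""
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out

def make_generator_input(labels, noise_dim=100, num_classes=10, seed=0):
    """Hypothetical helper: concatenate noise z with the one-hot condition y."""
    rng = np.random.default_rng(seed)
    z = rng.normal(0.0, 1.0, size=(len(labels), noise_dim)).astype(np.float32)
    y = one_hot(np.asarray(labels), num_classes)
    return np.concatenate([z, y], axis=1)

g_in = make_generator_input([7, 3])
print(g_in.shape)  # (2, 110)
```

The last 10 columns of each row carry the condition, so the same noise paired with a different label gives the generator a different input.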

The training objective is a modified minimax game:

\[ \min_G \max_D V(D, G) = \mathbb{E}_{(x, y) \sim p_{\text{data}}(x, y)}[\log D(x, y)] + \mathbb{E}_{z \sim p_z(z),\, y \sim p_y(y)}[\log (1 - D(G(z, y), y))] \]

where \( x \) is real data, \( y \) is the condition, and \( z \) is noise. The generator learns to produce data aligned with the condition, while the discriminator ensures the data is both realistic and condition-appropriate.
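
To make the objective concrete, here is a minimal NumPy sketch that estimates \( V(D, G) \) from a batch of discriminator outputs. The function name `value_function` and the example probabilities are made up for illustration; in practice the two arrays would be \( D(x, y) \) on real pairs and \( D(G(z, y), y) \) on generated pairs.

```python
import numpy as np

def value_function(d_real, d_fake, eps=1e-7):
    """Batch estimate of V(D, G) from discriminator outputs in (0, 1)."""
    d_real = np.clip(d_real, eps, 1.0 - eps)
    d_fake = np.clip(d_fake, eps, 1.0 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# A discriminator that separates real from fake well (illustrative numbers)
confident = value_function(np.array([0.9, 0.95]), np.array([0.1, 0.05]))
# A discriminator that cannot tell them apart and outputs 0.5 everywhere
fooled = value_function(np.array([0.5, 0.5]), np.array([0.5, 0.5]))

print(confident > fooled)  # True
```

The discriminator pushes this value up, while the generator pushes it down by making \( D(G(z, y), y) \) larger.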

Key Characteristics

  • Conditional Control: Generates data tailored to specific inputs, unlike unconditional GANs.
  • Adversarial Training: Maintains the competitive dynamic of GANs, with added complexity from conditions.
  • Challenges: Requires careful balancing of generator and discriminator, plus proper conditioning.

For more on building GANs, see Building GAN.

External Reference: Conditional Generative Adversarial Nets – Mirza and Osindero’s paper introducing cGANs.

Implementing a Conditional GAN in TensorFlow

We’ll build a conditional DCGAN to generate MNIST digits conditioned on class labels (0–9). The generator will take noise and a one-hot encoded label, while the discriminator will evaluate images paired with their labels.

Step 1: Loading and Preprocessing the MNIST Dataset

Load the MNIST dataset, normalize images to [-1, 1], and prepare one-hot encoded labels for conditioning.

import tensorflow as tf
from tensorflow.keras.datasets import mnist
import numpy as np

# Load MNIST dataset
(x_train, y_train), (_, _) = mnist.load_data()

# Normalize and reshape images
x_train = x_train.astype('float32')
x_train = (x_train / 255.0) * 2 - 1  # Scale to [-1, 1]
x_train = x_train.reshape(-1, 28, 28, 1)

# One-hot encode labels
num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train, num_classes)

# Create TensorFlow dataset
batch_size = 256
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).batch(batch_size)
  • Normalization: Scales pixel values to [-1, 1] for tanh output.
  • One-Hot Encoding: Converts labels (0–9) to 10-dimensional vectors.
  • Dataset: Pairs images and labels for training.
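
The scaling and encoding steps above can be checked on a few values without loading the dataset. This is a small NumPy sketch; the helper names `to_tanh_range`, `to_display_range`, and `one_hot` are hypothetical and only mirror the arithmetic used in this guide.

```python
import numpy as np

def to_tanh_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1]."""
    return pixels.astype(np.float32) / 255.0 * 2.0 - 1.0

def to_display_range(x):
    """Map values in [-1, 1] back to [0, 1] for plotting."""
    return 0.5 * x + 0.5

def one_hot(class_ids, num_classes=10):
    """Index rows of an identity matrix to get one-hot vectors."""
    return np.eye(num_classes, dtype=np.float32)[class_ids]

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)               # endpoints map to -1 and 1
restored = to_display_range(scaled) * 255.0  # round trip back to pixel values
encoded = one_hot(np.array([0, 9]))          # shape (2, 10)
```

The inverse mapping `0.5 * x + 0.5` is the same rescaling applied before plotting generated images later in the guide.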

For more on loading datasets, see Loading Image Datasets.

External Reference: MNIST Dataset – Official MNIST dataset documentation.

Step 2: Building the Generator

The generator takes a 100-dimensional noise vector and a one-hot encoded label, combining them to generate a 28x28x1 image.

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2DTranspose, BatchNormalization, LeakyReLU, Concatenate

def build_generator(noise_dim=100, num_classes=10):
    noise_input = Input(shape=(noise_dim,))
    label_input = Input(shape=(num_classes,))
    x = Concatenate()([noise_input, label_input])
    x = Dense(7*7*256)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((7, 7, 256))(x)
    x = Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh')(x)
    return Model([noise_input, label_input], x)

generator = build_generator()
generator.summary()
  • Concatenate: Merges noise and label inputs.
  • Dense and Reshape: Maps to a 7x7x256 feature map.
  • Conv2DTranspose: Upsamples to 28x28x1.
  • BatchNormalization: Stabilizes training.
  • LeakyReLU: Adds non-linearity.
  • tanh: Outputs pixel values in [-1, 1].
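
It can help to trace the spatial sizes by hand before reading the model summary. With `padding='same'`, a `Conv2DTranspose` layer multiplies the spatial size by its stride, so the three layers above go from 7 to 28. The helper below is a hypothetical name used only for this arithmetic.

```python
def conv_transpose_same(size, stride):
    """Spatial output size of Conv2DTranspose with padding='same'."""
    return size * stride

dense_units = 7 * 7 * 256  # units needed before reshaping to (7, 7, 256)

sizes = [7]
for stride in (1, 2, 2):   # strides of the three Conv2DTranspose layers
    sizes.append(conv_transpose_same(sizes[-1], stride))

print(dense_units)  # 12544
print(sizes)        # [7, 7, 14, 28]
```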

For convolutional layers, see Convolution Operations.

Step 3: Building the Discriminator

The discriminator takes a 28x28x1 image and a one-hot encoded label, predicting whether the image is real or fake.

from tensorflow.keras.layers import Conv2D, Flatten, Dropout

def build_discriminator(num_classes=10):
    image_input = Input(shape=(28, 28, 1))
    label_input = Input(shape=(num_classes,))
    label_dense = Dense(28*28)(label_input)
    label_reshaped = Reshape((28, 28, 1))(label_dense)
    x = Concatenate()([image_input, label_reshaped])
    x = Conv2D(64, (5, 5), strides=(2, 2), padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.3)(x)
    x = Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.3)(x)
    x = Flatten()(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model([image_input, label_input], x)

discriminator = build_discriminator()
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
discriminator.summary()
  • Label Processing: Expands the label to match the image dimensions for concatenation.
  • Conv2D: Extracts features with downsampling.
  • LeakyReLU and Dropout: LeakyReLU keeps a small gradient for negative activations, and Dropout helps prevent overfitting.
  • sigmoid: Outputs a real/fake probability.
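
The same kind of shape check works for the discriminator. With `padding='same'`, a strided `Conv2D` produces a spatial size of ceil(input / stride), and concatenating the reshaped label adds one channel to the image. The helper names below are hypothetical and the sketch only does the arithmetic.

```python
import math

def conv_same(size, stride):
    """Spatial output size of Conv2D with padding='same'."""
    return math.ceil(size / stride)

def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a square-kernel Conv2D layer."""
    return kernel * kernel * in_ch * out_ch + out_ch

in_channels = 1 + 1  # image channel + reshaped label channel

sizes = [28]
for stride in (2, 2):  # strides of the two Conv2D layers
    sizes.append(conv_same(sizes[-1], stride))

flat_units = sizes[-1] * sizes[-1] * 128  # input width of the final Dense(1)
first_conv_params = conv_params(5, in_channels, 64)

print(sizes)              # [28, 14, 7]
print(flat_units)         # 6272
print(first_conv_params)  # 3264
```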

For building CNNs, see Building CNN.

Step 4: Combining the cGAN

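Combine the generator and discriminator, freezing the discriminator’s weights during generator training. In tf.keras (Keras 2), the trainable flag takes effect at compile time, so the discriminator compiled in Step 3 still updates its weights when trained directly, while the combined model compiled below treats them as fixed.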

# Freeze discriminator weights during cGAN training
discriminator.trainable = False

# Define cGAN
noise_input = Input(shape=(100,))
label_input = Input(shape=(num_classes,))
gen_output = generator([noise_input, label_input])
cgan_output = discriminator([gen_output, label_input])
cgan = Model([noise_input, label_input], cgan_output)
cgan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
             loss='binary_crossentropy')
cgan.summary()

Step 5: Training the cGAN

Train the discriminator and generator alternately, using real images with labels and fake images with corresponding labels.

import matplotlib.pyplot as plt

def train_cgan(epochs, batch_size, noise_dim=100, num_classes=10):
    for epoch in range(epochs):
        for batch_images, batch_labels in dataset:
            # Size targets and noise to the actual batch; the last batch may be smaller
            n = batch_images.shape[0]
            real = np.ones((n, 1))
            fake = np.zeros((n, 1))

            # Train discriminator on real pairs, then on generated pairs
            noise = np.random.normal(0, 1, (n, noise_dim))
            gen_imgs = generator.predict([noise, batch_labels], verbose=0)
            d_loss_real = discriminator.train_on_batch([batch_images, batch_labels], real)
            d_loss_fake = discriminator.train_on_batch([gen_imgs, batch_labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train generator through the combined model (discriminator frozen)
            noise = np.random.normal(0, 1, (n, noise_dim))
            sampled_labels = tf.keras.utils.to_categorical(np.random.randint(0, num_classes, n), num_classes)
            g_loss = cgan.train_on_batch([noise, sampled_labels], real)

        # Display progress every 10 epochs
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, D Loss: {d_loss[0]:.4f}, D Acc: {d_loss[1]:.4f}, G Loss: {g_loss:.4f}")
            # Generate and plot sample images for digits 0-9
            noise = np.random.normal(0, 1, (10, noise_dim))
            labels = tf.keras.utils.to_categorical(np.arange(0, 10), num_classes)
            gen_imgs = generator.predict([noise, labels], verbose=0)
            gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0, 1]
            fig, axes = plt.subplots(2, 5, figsize=(15, 6))
            for i, ax in enumerate(axes.flat):
                ax.imshow(gen_imgs[i, :, :, 0], cmap='gray')
                ax.set_title(f"Digit {i}")
                ax.axis('off')
            plt.show()

# Train the cGAN
train_cgan(epochs=100, batch_size=batch_size)
  • Discriminator Training: Uses real images with labels (1) and fake images with labels (0).
  • Generator Training: Trains the generator to produce images that match the conditioned labels, fooling the discriminator.
  • Visualization: Generates one image per digit (0–9) every 10 epochs.
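
A detail that is easy to miss in a loop like this: 60,000 examples in batches of 256 leave a short final batch, and `train_on_batch` needs target arrays whose first dimension matches each batch. The plain NumPy sketch below shows the arithmetic and one way to size the targets per batch; `batch_sizes` and `make_targets` are hypothetical helper names. Passing `drop_remainder=True` to `Dataset.batch` is an alternative that simply discards the short batch.

```python
import numpy as np

def batch_sizes(num_examples, batch_size):
    """Sizes of the batches produced when the remainder is kept."""
    full, rest = divmod(num_examples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

def make_targets(n):
    """Real (1) and fake (0) targets sized to a batch of n examples."""
    return np.ones((n, 1), dtype=np.float32), np.zeros((n, 1), dtype=np.float32)

sizes = batch_sizes(60000, 256)
print(len(sizes), sizes[-1])  # 235 96

real, fake = make_targets(sizes[-1])
print(real.shape, fake.shape)  # (96, 1) (96, 1)
```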

For training techniques, see Training Network.

Step 6: Generating Conditional Images

Generate digits for specific labels using the trained generator:

# Generate images for digits 0-4
noise = np.random.normal(0, 1, (5, 100))
labels = tf.keras.utils.to_categorical(np.arange(0, 5), num_classes)
gen_imgs = generator.predict([noise, labels])
gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0, 1]

# Plot generated images
plt.figure(figsize=(15, 3))
for i in range(5):
    plt.subplot(1, 5, i+1)
    plt.imshow(gen_imgs[i, :, :, 0], cmap='gray')
    plt.title(f"Digit {i}")
    plt.axis('off')
plt.show()
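
To judge diversity within a class, it is often useful to generate several samples per digit rather than one. The sketch below only builds the inputs in NumPy; `conditional_batch` is a hypothetical helper, and the commented-out last line assumes the `generator` trained earlier in this guide.

```python
import numpy as np

def conditional_batch(samples_per_class, num_classes=10, noise_dim=100, seed=0):
    """Build noise and one-hot labels for several samples of every class."""
    rng = np.random.default_rng(seed)
    class_ids = np.repeat(np.arange(num_classes), samples_per_class)
    labels = np.eye(num_classes, dtype=np.float32)[class_ids]
    noise = rng.normal(0.0, 1.0, (len(class_ids), noise_dim)).astype(np.float32)
    return noise, labels, class_ids

noise, labels, class_ids = conditional_batch(samples_per_class=4)
print(noise.shape, labels.shape)  # (40, 100) (40, 10)

# gen_imgs = generator.predict([noise, labels])  # rows 0-3 are digit 0, rows 4-7 digit 1, ...
```

If the four images for a given digit look nearly identical, the generator is probably ignoring the noise for that class.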

For another MNIST-based project, see MNIST Classification.

Advanced cGAN Techniques

Wasserstein cGAN

Adapt Wasserstein GAN (WGAN) principles to cGANs for improved stability, using Wasserstein loss and weight clipping:

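import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop

# Wasserstein loss; assumes targets of -1 (real) and +1 (fake)
def wasserstein_loss(y_true, y_pred):
    return tf.reduce_mean(y_true * y_pred)

# The critic's last layer should be Dense(1) with no sigmoid for this loss
discriminator.compile(optimizer=RMSprop(learning_rate=0.00005), loss=wasserstein_loss)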
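
The weight-clipping half of the recipe can be sketched in NumPy. The WGAN paper clips every critic weight to [-c, c] after each update, with c = 0.01 as its default. The helper names below are hypothetical, and the target convention (-1 for real, +1 for fake) is an assumption; the commented line shows how the clipping might be applied to a Keras model through `get_weights` and `set_weights`.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    """Mean of target times critic score; targets are -1 (real) or +1 (fake)."""
    return float(np.mean(y_true * y_pred))

def clip_weights(weights, clip_value=0.01):
    """Clip every weight array to [-clip_value, clip_value]."""
    return [np.clip(w, -clip_value, clip_value) for w in weights]

# Critic scores for two real and two fake samples (illustrative numbers)
targets = np.array([-1.0, -1.0, 1.0, 1.0])
scores = np.array([2.0, 1.0, -1.0, 0.0])
loss = wasserstein_loss(targets, scores)
print(loss)  # -1.0

toy_weights = [np.array([[0.5, -0.002], [0.03, -0.7]]), np.array([0.004, -0.2])]
clipped = clip_weights(toy_weights)

# After each critic update (assumed usage):
# discriminator.set_weights(clip_weights(discriminator.get_weights()))
```

A lower loss here means the critic scores real samples higher than fake ones, which is the direction it is trained in.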

For more, see Model Optimization.

External Reference: Wasserstein GAN – Paper introducing WGANs, adaptable to cGANs.

Label Smoothing

Use smoothed labels to prevent the discriminator from becoming too confident:

real = np.random.uniform(0.9, 1.0, (batch_size, 1))
fake = np.random.uniform(0.0, 0.1, (batch_size, 1))
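
A quick way to see why smoothing helps is to compare the binary cross-entropy of an overconfident prediction against a hard target and a smoothed one. This is a NumPy sketch with a hypothetical `bce` helper; the target 0.9 and the prediction 0.999 are illustrative values.

```python
import numpy as np

def bce(target, pred, eps=1e-7):
    """Binary cross-entropy for a single prediction."""
    pred = np.clip(pred, eps, 1.0 - eps)
    return float(-(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred)))

hard = bce(1.0, 0.999)      # hard target rewards extreme confidence
smooth = bce(0.9, 0.999)    # smoothed target penalizes it
at_target = bce(0.9, 0.9)   # the smoothed loss is lowest when pred equals the target

print(smooth > hard)        # True
print(at_target < smooth)   # True
```

With a smoothed target, the discriminator no longer gains by pushing its output all the way to 1, which tends to keep useful gradients flowing to the generator.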

For more, see Neural Network Best Practices.

Embedding-Based Conditioning

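Instead of one-hot labels, use an embedding layer that maps integer class IDs to learned dense vectors. This scales better to discrete conditions with many categories (e.g., token IDs from text descriptions):

from tensorflow.keras.layers import Embedding, Flatten
label_input = Input(shape=(1,), dtype='int32')  # Integer class ID, not one-hot
label_embedded = Embedding(num_classes, 50)(label_input)  # Shape: (batch, 1, 50)
label_embedded = Flatten()(label_embedded)  # Shape: (batch, 50), ready to concatenate with noise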
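
Conceptually, an embedding lookup is the same as multiplying a one-hot vector by a weight matrix, just without materializing the one-hot vector. The NumPy sketch below checks that equivalence on a random table; the table stands in for the learned embedding weights, and the sizes (10 classes, 50 dimensions) are taken from the snippet above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, embed_dim = 10, 50
table = rng.normal(size=(num_classes, embed_dim))  # stand-in for learned weights

class_ids = np.array([3, 7, 3])
looked_up = table[class_ids]               # embedding lookup by row index
one_hot = np.eye(num_classes)[class_ids]   # explicit one-hot encoding
via_matmul = one_hot @ table               # same result through a matrix product

print(looked_up.shape)                     # (3, 50)
print(np.allclose(looked_up, via_matmul))  # True
```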

For more, see Text-to-Image Synthesis.

Common Challenges and Solutions

Training Instability

cGANs are sensitive to hyperparameters. Use a low learning rate (0.0002), Adam with beta_1=0.5, or WGAN loss. Ensure the discriminator doesn’t overpower the generator by monitoring losses.

Mode Collapse

The generator may collapse to a few outputs, producing near-identical images for a given label or ignoring the label altogether. Increase generator capacity or use feature matching:

generator = build_generator()  # Rebuild after widening the Dense or Conv2DTranspose layers in build_generator
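
Feature matching, as proposed by Salimans et al. (2016), trains the generator to match the mean of an intermediate discriminator feature over real and generated batches instead of directly maximizing the discriminator's output. The sketch below shows only the loss computation in NumPy; `feature_matching_loss` is a hypothetical name, and the feature arrays are toy values standing in for activations of an intermediate discriminator layer.

```python
import numpy as np

def feature_matching_loss(real_features, fake_features):
    """Squared L2 distance between batch-mean feature vectors."""
    diff = real_features.mean(axis=0) - fake_features.mean(axis=0)
    return float(np.sum(diff ** 2))

# Toy features: 4 real and 8 fake samples, 3 feature dimensions each
real_features = np.zeros((4, 3))
fake_features = np.ones((8, 3))

print(feature_matching_loss(real_features, real_features))  # 0.0
print(feature_matching_loss(real_features, fake_features))  # 3.0
```

Because the loss depends on batch statistics rather than on fooling the discriminator sample by sample, it gives the generator less incentive to concentrate on a single convincing output.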

Overfitting

The discriminator may overfit to real data-label pairs. Increase dropout (used in discriminator) or apply data augmentation (Image Augmentation).

Computational Cost

cGANs are resource-intensive because every training step updates two networks; the conditional inputs themselves add comparatively little overhead. Use GPUs or TPUs for faster training (TPU Acceleration).

External Reference: GANs in Action – Book covering cGAN training challenges and solutions.

Practical Applications

cGANs are versatile for conditional generation:

  • Labeled Image Generation: Generate specific classes ([Fashion MNIST](/tensorflow/projects/fashion-mnist)).
  • Image-to-Image Translation: Generate images conditioned on an input image ([Pix2Pix](/tensorflow/computer-vision/pix2pix)).
  • Data Augmentation: Augment labeled datasets ([Image Augmentation](/tensorflow/computer-vision/image-augmentation)).

External Reference: TensorFlow Models Repository – Pre-trained cGAN models.

Conclusion

Conditional GANs in TensorFlow enable targeted data generation, offering control over synthetic outputs through conditional inputs. By building a conditional DCGAN for MNIST digit generation and exploring advanced techniques like WGANs and embedding-based conditioning, you’ve gained practical skills in conditional generative modeling. The code and resources in this guide offer a foundation for further experiments, such as adapting cGANs to labeled image synthesis or text-to-image generation.