Source code for GAN_CIFAR10_jax

import jax.numpy as jnp
from flax import linen as nn

[docs] class Discriminator(nn.Module): ''' Discriminator class for an unconditional GAN model. Applies several convolutional layers followed by LeakyReLU activations to classify the input. Methods: __call__: The forward pass through the network. '''
[docs] @nn.compact def __call__(self, x): ''' Forward pass through the discriminator. Args: x (jnp.ndarray): Input image tensor of shape (batch_size, height, width, channels). Returns: jnp.ndarray: A scalar score for each image in the batch. ''' x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=128, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=256, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=512, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = x.reshape((x.shape[0], -1)) # Flatten the feature map x = nn.Dense(features=1)(x) # Output a single scalar value return x
[docs] class Generator(nn.Module): ''' Generator class for an unconditional GAN model. Takes a latent code and generates images using deconvolutional layers. Methods: __call__: The forward pass through the network. '''
[docs] @nn.compact def __call__(self, z): ''' Forward pass through the generator. Args: z (jnp.ndarray): Input latent vector of shape (batch_size, latent_dim). Returns: jnp.ndarray: Generated image tensor of shape (batch_size, 32, 32, 3). ''' x = nn.Dense(features=4 * 4 * 512)(z) x = nn.leaky_relu(x, negative_slope=0.2) x = x.reshape((x.shape[0], 4, 4, 512)) # Reshape to start conv layers x = nn.ConvTranspose(features=256, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 8, 8, 256) x = nn.ConvTranspose(features=128, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 16, 16, 128) x = nn.ConvTranspose(features=64, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 32, 32, 64) x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(x) x = nn.tanh(x) # Output shape: (batch_size, 32, 32, 3) return x
[docs] class Discriminator_cond(nn.Module): ''' Conditional Discriminator class for a GAN model. Takes both images and labels as input and discriminates between real and fake images conditioned on the labels. Methods: __call__: The forward pass through the network. '''
[docs] @nn.compact def __call__(self, x, labels): ''' Forward pass through the conditional discriminator. Args: x (jnp.ndarray): Input image tensor of shape (batch_size, height, width, channels). labels (jnp.ndarray): One-hot encoded labels of shape (batch_size, num_classes). Returns: jnp.ndarray: A scalar score for each image in the batch, conditioned on labels. ''' labels = labels.reshape((labels.shape[0], 1, 1, 10)) labels = jnp.broadcast_to(labels, (x.shape[0], x.shape[1], x.shape[2], 10)) x = jnp.concatenate([x, labels], axis=-1) x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=128, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=256, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = nn.Conv(features=512, kernel_size=(3, 3), strides=(2, 2))(x) x = nn.leaky_relu(x, negative_slope=0.2) x = x.reshape((x.shape[0], -1)) # Flatten the feature map x = nn.Dense(features=1)(x) # Output a single scalar value return x
[docs] class Generator_cond(nn.Module): ''' Conditional Generator class for a GAN model. Takes latent codes and labels to generate conditional images. Methods: __call__: The forward pass through the network. '''
[docs] @nn.compact def __call__(self, labels, z): ''' Forward pass through the conditional generator. Args: labels (jnp.ndarray): One-hot encoded labels of shape (batch_size, num_classes). z (jnp.ndarray): Latent vectors of shape (batch_size, latent_dim). Returns: jnp.ndarray: Generated image tensor of shape (batch_size, 32, 32, 3), conditioned on labels. ''' x = jnp.concatenate([z, labels], axis=-1) x = nn.Dense(features=4 * 4 * 512)(x) x = nn.leaky_relu(x, negative_slope=0.2) x = x.reshape((x.shape[0], 4, 4, 512)) # Reshape to start conv layers x = nn.ConvTranspose(features=256, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 8, 8, 256) x = nn.ConvTranspose(features=128, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 16, 16, 128) x = nn.ConvTranspose(features=64, kernel_size=(4, 4), strides=(2, 2), padding='SAME')(x) x = nn.leaky_relu(x, negative_slope=0.2) # Output shape: (batch_size, 32, 32, 64) x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(x) x = nn.tanh(x) # Output shape: (batch_size, 32, 32, 3) return x