"""Source code for GAN_MNIST_jax: conditional GAN generator and discriminator for MNIST (JAX/Flax)."""

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from jax import random
import optax


class Generator_MNIST_cond(nn.Module):
    """Conditional generator producing 28x28 single-channel images.

    The latent code and the class labels are concatenated and fed through a
    stack of fully connected layers. The result is reshaped into a 7x7 feature
    map with 128 channels and upsampled by transposed convolutions to 28x28x1.

    Attributes:
        latent_dim: Declared size of the latent vector ``z``. It is not read in
            ``__call__``; the first Dense layer infers its input size.
        num_classes: Declared number of label classes. Likewise not read in
            ``__call__``.
    """

    latent_dim: int = 118  # Dimension of latent vector z
    num_classes: int = 10  # Number of classes for labels (MNIST has 10)

    @nn.compact
    def __call__(self, labels, z):
        """Generate a batch of images conditioned on ``labels``.

        Args:
            labels: One-hot encoded class labels, shape (batch_size, num_classes).
            z: Latent vector, shape (batch_size, latent_dim).

        Returns:
            Generated images, shape (batch_size, 28, 28, 1), with values in
            [-1, 1] because of the final tanh.
        """
        # Condition the generator by appending the labels to the latent code.
        h = jnp.concatenate([z, labels], axis=-1)

        # Fully connected trunk with ReLU after every layer. The last width is
        # exactly 7 * 7 * 128 so the activations can be viewed as a 7x7x128 map.
        for width in (256, 512, 1024, 7 * 7 * 128):
            h = nn.relu(nn.Dense(width)(h))
        h = h.reshape((-1, 7, 7, 128))

        # Two stride-2 transposed convolutions: 7x7 -> 14x14 -> 28x28.
        for channels in (64, 32):
            h = nn.ConvTranspose(
                features=channels, kernel_size=(5, 5), strides=(2, 2), padding='SAME'
            )(h)
            h = nn.relu(h)

        # Stride-1 projection down to a single channel, squashed into [-1, 1].
        h = nn.ConvTranspose(
            features=1, kernel_size=(5, 5), strides=(1, 1), padding='SAME'
        )(h)
        return nn.tanh(h)
class Discriminator_MNIST_cond(nn.Module):
    """Conditional discriminator that scores (image, label) pairs.

    The one-hot labels are broadcast over the image's spatial grid and appended
    as extra input channels. Two stride-2 convolutions (LeakyReLU + Dropout)
    are followed by a fully connected head that emits one raw logit per
    example; no sigmoid is applied.
    """

    @nn.compact
    def __call__(self, x, labels, training: bool = True):
        """Score a batch of images conditioned on their labels.

        Args:
            x: Input images, shape (batch_size, height, width, channels); for
                MNIST this is (batch_size, 28, 28, 1).
            labels: One-hot encoded class labels, shape (batch_size, num_classes).
            training: Whether the model is in training mode (for Dropout). When
                True, Dropout is stochastic, so ``apply`` must be given a
                ``'dropout'`` PRNG key (the default RNG collection used by
                Flax's ``nn.Dropout``).

        Returns:
            Output logit, shape (batch_size, 1).
        """
        # Broadcast the labels over the spatial grid of ``x`` and append them as
        # channels. The grid size comes from the input instead of being
        # hard-coded to 28x28, so other image sizes also work.
        _, height, width = x.shape[:3]
        label_maps = jnp.broadcast_to(
            labels[:, None, None, :],
            (labels.shape[0], height, width, labels.shape[-1]),
        )
        x = jnp.concatenate([x, label_maps], axis=-1)

        # Convolutional layers with Leaky ReLU activations; each stride-2 conv
        # halves the spatial resolution (28 -> 14 -> 7 for MNIST).
        x = nn.Conv(features=64, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
        x = nn.leaky_relu(x, negative_slope=0.2)
        x = nn.Dropout(0.3, deterministic=not training)(x)
        x = nn.Conv(features=128, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
        x = nn.leaky_relu(x, negative_slope=0.2)
        x = nn.Dropout(0.3, deterministic=not training)(x)

        # Flatten the feature map. The following Dense layer infers its input
        # size, so the head adapts to whatever spatial size reaches it.
        x = x.reshape((x.shape[0], -1))

        # Fully connected layer followed by the output layer (real/fake logit).
        x = nn.Dense(1024)(x)
        x = nn.leaky_relu(x, negative_slope=0.2)
        x = nn.Dense(1)(x)
        return x