"""Source code for model_jax: Flax/Linen discriminator and generator networks."""

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn


class Discriminator(nn.Module):
    """MLP discriminator with optional spectral normalization.

    Each hidden layer is ``Dense -> relu``, followed by a final ``Dense(1)``.
    The final output can optionally be passed through ``bounded_activation``.

    Attributes:
        input_dim: Input dimension size. Not read by the forward pass;
            ``nn.Dense`` infers its input size from ``x``.
        spec_norm: If True, every ``nn.Dense`` layer is wrapped in
            ``nn.SpectralNorm``.
        bounded: If True, the output is squashed with ``bounded_activation``.
        layers_list: Number of units in each hidden layer, in order.
    """

    input_dim: int
    spec_norm: bool
    bounded: bool
    layers_list: list

    @staticmethod
    def bounded_activation(x, bound=100.0):
        """Smoothly constrain ``x`` to lie within ``[-bound, bound]``.

        Computes ``bound * tanh(x / bound)``. This is approximately the
        identity for ``|x|`` much smaller than ``bound`` and saturates at
        ``+/-bound``.

        Args:
            x: Input array.
            bound: Half-width of the output range. Defaults to 100.0, the
                value previously hard-coded here.

        Returns:
            Array of the same shape as ``x``.
        """
        # staticmethod: the function takes no ``self`` but is invoked as
        # ``self.bounded_activation(x)``. As a plain method, that call passed
        # two positional arguments and raised a TypeError.
        return bound * jnp.tanh(x / bound)

    @nn.compact
    def __call__(self, x, update_stats: bool = False):
        """Forward pass through the discriminator.

        Args:
            x: Input data to the network.
            update_stats: Forwarded to ``nn.SpectralNorm`` (ignored when
                ``spec_norm`` is False). Pass True during training together
                with ``mutable=['batch_stats']`` in ``Module.apply`` so the
                power-iteration estimates are stored.

        Returns:
            Array whose last axis has size 1: the discriminator's score,
            optionally bounded.
        """

        def dense(features, h):
            # Build one Dense layer, spectrally normalized if requested. The
            # construction order matches the original code, so Flax's
            # auto-generated parameter names are unchanged.
            layer = nn.Dense(features)
            if self.spec_norm:
                # ``update_stats`` is a required keyword of SpectralNorm.__call__.
                return nn.SpectralNorm(layer)(h, update_stats=update_stats)
            return layer(h)

        for h_dim in self.layers_list:
            x = nn.relu(dense(h_dim, x))
        x = dense(1, x)
        if self.bounded:
            x = self.bounded_activation(x)
        return x
class Generator(nn.Module):
    """MLP generator mapping latent vectors to generated samples.

    Each hidden layer is ``Dense -> relu``, followed by a linear
    ``Dense(X_dim)`` output layer with no activation.

    Attributes:
        X_dim: Size of the generated output (last axis).
        Z_dim: Size of the latent input. Not read by the forward pass;
            ``nn.Dense`` infers its input size from ``x``.
        spec_norm: If True, every ``nn.Dense`` layer is wrapped in
            ``nn.SpectralNorm``.
        layers_list: Number of units in each hidden layer, in order.
    """

    X_dim: int
    Z_dim: int
    spec_norm: bool
    layers_list: list

    @nn.compact
    def __call__(self, x, update_stats: bool = False):
        """Forward pass through the generator.

        Args:
            x: Input latent vectors.
            update_stats: Forwarded to ``nn.SpectralNorm`` (ignored when
                ``spec_norm`` is False). Pass True during training together
                with ``mutable=['batch_stats']`` in ``Module.apply``.

        Returns:
            Generated output whose last axis has size ``X_dim``.
        """

        def dense(features, h):
            # Build and apply one Dense layer, spectrally normalized if
            # requested. Construction order matches the original non-spectral
            # branch, so Flax's auto-generated parameter names are unchanged.
            layer = nn.Dense(features)
            if self.spec_norm:
                # ``update_stats`` is a required keyword of SpectralNorm.__call__.
                return nn.SpectralNorm(layer)(h, update_stats=update_stats)
            return layer(h)

        for h_dim in self.layers_list:
            x = nn.relu(dense(h_dim, x))
        # The output layer must be *applied* to x. The previous spectral-norm
        # branch built the module but never called it, so it returned a Module.
        return dense(self.X_dim, x)
class DiscriminatorMNIST(nn.Module):
    """Fixed-architecture MLP discriminator for MNIST images.

    Inputs are flattened to rows of 784 values and passed through six blocks
    of ``Dense -> LayerNorm -> leaky_relu(0.2)``. These are followed by
    ``Dense(16) -> leaky_relu(0.2)`` (no LayerNorm) and a final ``Dense(1)``.
    """

    @nn.compact
    def __call__(self, x):
        """Score a batch of images.

        Args:
            x: Image data. It is reshaped to ``(-1, 784)``, so its total
                size must be a multiple of 784.

        Returns:
            Array of shape ``(batch, 1)``: one unnormalized score per image.
        """
        # NOTE(review): the first two widths are 794, not 784. Confirm this is
        # intended rather than a typo for the input size.
        normalized_widths = (794, 794, 256, 128, 64, 32)

        h = x.reshape((-1, 784))
        for width in normalized_widths:
            h = nn.Dense(width)(h)
            h = nn.LayerNorm()(h)
            h = nn.leaky_relu(h, negative_slope=0.2)

        # The last hidden block intentionally has no LayerNorm.
        h = nn.leaky_relu(nn.Dense(16)(h), negative_slope=0.2)
        return nn.Dense(1)(h)