import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
class Discriminator(nn.Module):
    """
    MLP discriminator with optional spectral normalization and bounded output.

    Attributes:
        input_dim (int): Input dimension size. Not read by the forward pass;
            ``nn.Dense`` infers its input size from the data.
        spec_norm (bool): Whether to wrap every ``nn.Dense`` layer in
            ``nn.SpectralNorm``.
        bounded (bool): Whether to apply ``bounded_activation`` to the final
            output.
        layers_list (list): Number of units in each hidden layer, in order.
    """

    input_dim: int
    spec_norm: bool
    bounded: bool
    layers_list: list

    # A staticmethod (rather than a plain function in the class body) so that
    # ``self.bounded_activation(x)`` does not pass ``self`` as ``x``; Flax also
    # leaves staticmethods unwrapped.
    @staticmethod
    def bounded_activation(x, bound=100.0):
        """
        Softly clip ``x`` to the open interval (-bound, bound).

        Computes ``bound * tanh(x / bound)``. For |x| much smaller than
        ``bound`` the result is close to ``x``; larger values saturate
        towards ``±bound``.

        Args:
            x (jax.Array): Input data.
            bound (float): Saturation level. Defaults to 100.0, the value
                previously hard-coded.

        Returns:
            jax.Array: Bounded output with the same shape as ``x``.
        """
        return bound * jnp.tanh(x / bound)

    @nn.compact
    def __call__(self, x, *, update_stats=False):
        """
        Forward pass through the discriminator.

        Each hidden layer is Dense (optionally spectrally normalized)
        followed by ReLU. The final layer is a single-unit Dense with no
        activation, optionally passed through ``bounded_activation``.

        Args:
            x (jax.Array): Input data to the network.
            update_stats (bool): Forwarded to ``nn.SpectralNorm``. When True,
                the power-iteration state stored in the ``batch_stats``
                collection is written back. The caller must then make that
                collection mutable in ``Module.apply``. Ignored when
                ``spec_norm`` is False.

        Returns:
            jax.Array: Discriminator output whose last axis has size 1.
        """
        for h_dim in self.layers_list:
            if self.spec_norm:
                # SpectralNorm requires update_stats as a keyword argument.
                x = nn.SpectralNorm(nn.Dense(h_dim))(x, update_stats=update_stats)
            else:
                x = nn.Dense(h_dim)(x)
            x = nn.relu(x)
        if self.spec_norm:
            x = nn.SpectralNorm(nn.Dense(1))(x, update_stats=update_stats)
        else:
            x = nn.Dense(1)(x)
        if self.bounded:
            x = self.bounded_activation(x)
        return x
class Generator(nn.Module):
    """
    MLP generator with optional spectral normalization.

    Attributes:
        X_dim (int): Size of the last axis of the generated output.
        Z_dim (int): Dimension of the input latent space. Not read by the
            forward pass; ``nn.Dense`` infers its input size from the data.
        spec_norm (bool): Whether to wrap every ``nn.Dense`` layer in
            ``nn.SpectralNorm``.
        layers_list (list): Number of units in each hidden layer, in order.
    """

    X_dim: int
    Z_dim: int
    spec_norm: bool
    layers_list: list

    @nn.compact
    def __call__(self, x, *, update_stats=False):
        """
        Forward pass through the generator.

        Each hidden layer is Dense (optionally spectrally normalized)
        followed by ReLU. The output layer is Dense with ``X_dim`` units and
        no activation.

        Args:
            x (jax.Array): Input latent vector to the network.
            update_stats (bool): Forwarded to ``nn.SpectralNorm``. When True,
                the power-iteration state stored in the ``batch_stats``
                collection is written back. The caller must then make that
                collection mutable in ``Module.apply``. Ignored when
                ``spec_norm`` is False.

        Returns:
            jax.Array: Generated output whose last axis has size ``X_dim``.
        """
        for h_dim in self.layers_list:
            if self.spec_norm:
                # SpectralNorm requires update_stats as a keyword argument.
                x = nn.SpectralNorm(nn.Dense(h_dim))(x, update_stats=update_stats)
            else:
                x = nn.Dense(h_dim)(x)
            x = nn.relu(x)
        if self.spec_norm:
            # The original built this layer but never applied it to x, so it
            # returned a Module instead of an array.
            x = nn.SpectralNorm(nn.Dense(self.X_dim))(x, update_stats=update_stats)
        else:
            x = nn.Dense(self.X_dim)(x)
        return x
class DiscriminatorMNIST(nn.Module):
    """
    Fixed-architecture MLP discriminator for flattened MNIST-sized inputs.

    The network stacks six Dense -> LayerNorm -> leaky_relu(0.2) stages.
    These are followed by a 16-unit Dense + leaky_relu stage without
    LayerNorm, and then a single-unit Dense output layer with no activation.
    """

    @nn.compact
    def __call__(self, x):
        """
        Forward pass through the MNIST-specific discriminator.

        Args:
            x (jax.Array): Input image data. It is reshaped to ``(-1, 784)``,
                so every sample must contain exactly 784 elements.

        Returns:
            jax.Array: Output of shape ``(batch_size, 1)``, with no final
            activation applied.
        """
        h = x.reshape((-1, 784))  # flatten each sample to a 784-vector
        # NOTE(review): the first two widths are 794, not 784. Confirm this is
        # intentional; it is kept because changing it alters parameter shapes.
        for width in (794, 794, 256, 128, 64, 32):
            h = nn.Dense(width)(h)
            h = nn.LayerNorm()(h)
            h = nn.leaky_relu(h, negative_slope=0.2)
        # The last hidden stage deliberately has no LayerNorm.
        h = nn.leaky_relu(nn.Dense(16)(h), negative_slope=0.2)
        return nn.Dense(1)(h)