import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input, Activation, LayerNormalization, LeakyReLU, SpectralNormalization
from tensorflow.keras.models import Sequential, Model
class Discriminator(Model):
    """
    Fully connected discriminator network.

    Args:
        input_dim (int): Dimension of the input features.
        spec_norm (bool): Whether to wrap every Dense layer in SpectralNormalization.
        bounded (bool): Whether to append ``bounded_activation`` after the final layer.
        layers_list (list): Number of units for each hidden (ReLU) layer, in order.

    Methods:
        call(inputs): Forward pass through the network.
        bounded_activation(x): Static helper computing M * tanh(x / M) with M = 100,
            which keeps outputs within [-M, M].

    Returns:
        tf.Tensor: Output of shape (batch, 1) produced by the discriminator network.
    """

    def __init__(self, input_dim, spec_norm, bounded, layers_list):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.spec_norm = spec_norm
        self.bounded = bounded
        self.layers_list = layers_list

        def maybe_spectral_norm(layer):
            # Wrap the layer in SpectralNormalization only when requested.
            return SpectralNormalization(layer) if self.spec_norm else layer

        self.discriminator = Sequential()
        self.discriminator.add(Input(shape=(self.input_dim,)))
        for h_dim in self.layers_list:
            self.discriminator.add(maybe_spectral_norm(Dense(units=h_dim, activation='relu')))
        # Single linear output unit (one score per sample).
        self.discriminator.add(maybe_spectral_norm(Dense(units=1, activation='linear')))
        if self.bounded:
            # bounded_activation is a staticmethod, so this is a plain one-argument
            # callable. The original, without the decorator, was invoked as f(self, x)
            # and raised TypeError.
            self.discriminator.add(Activation(self.bounded_activation))
        print()
        print('Discriminator Summary:')
        self.discriminator.summary()

    def call(self, inputs):
        """
        Forward pass through the discriminator.

        Args:
            inputs (tf.Tensor): Input batch; the Input layer expects shape (batch, input_dim).

        Returns:
            tf.Tensor: Discriminator output of shape (batch, 1).
        """
        predicted = self.discriminator(inputs)
        return predicted

    @staticmethod
    def bounded_activation(x):
        """
        Apply a bounded activation M * tanh(x / M), keeping the output within [-M, M].

        M is fixed at 100.0. Near zero the function is approximately the identity,
        and it saturates smoothly toward +/-M for large |x|.

        Args:
            x (tf.Tensor): Input tensor.

        Returns:
            tf.Tensor: Tensor of the same shape as ``x`` with values in [-M, M].
        """
        M = 100.0
        # tf.math.tanh replaces K.tanh: the Keras 3 backend namespace no longer exposes tanh.
        return M * tf.math.tanh(x / M)
class Generator(Model):
    """
    Fully connected generator mapping latent vectors to samples.

    Args:
        X_dim (int): Dimension of each generated sample (units of the output layer).
        Z_dim (int): Dimension of the latent input vectors.
        spec_norm (bool): Whether to wrap every Dense layer in SpectralNormalization.
        layers_list (list): Number of units for each hidden (ReLU) layer, in order.

    Methods:
        call(inputs): Forward pass through the network.

    Returns:
        tf.Tensor: Generated output of shape (batch, X_dim).
    """

    def __init__(self, X_dim, Z_dim, spec_norm, layers_list):
        super(Generator, self).__init__()
        self.X_dim = X_dim
        self.Z_dim = Z_dim
        self.spec_norm = spec_norm
        self.layers_list = layers_list

        self.generator = Sequential()
        net = self.generator
        net.add(Input(shape=(self.Z_dim,)))

        # (units, activation) for every Dense layer: hidden ReLU layers, then a linear output.
        plan = [(units, 'relu') for units in self.layers_list]
        plan.append((X_dim, 'linear'))
        for units, act in plan:
            dense = Dense(units=units, activation=act)
            net.add(SpectralNormalization(dense) if self.spec_norm else dense)

        print()
        print('Generator Summary:')
        net.summary()

    def call(self, inputs):
        """
        Run latent vectors through the generator.

        Args:
            inputs (tf.Tensor): Latent batch; the Input layer expects shape (batch, Z_dim).

        Returns:
            tf.Tensor: Generated samples of shape (batch, X_dim).
        """
        return self.generator(inputs)
class DiscriminatorMNIST(tf.keras.Model):
    """
    Dense discriminator for flattened MNIST-style inputs.

    Each input sample is flattened to 784 features (presumably 28x28 images; confirm
    against callers). The features pass through six Dense -> LayerNorm -> LeakyReLU(0.2)
    stages, one Dense -> LeakyReLU stage, and a final single-unit Dense layer with no
    activation.

    Methods:
        call(x): Forward pass through the network.

    Returns:
        tf.Tensor: Unbounded score of shape (batch, 1) for each input row.
    """

    def __init__(self):
        super(DiscriminatorMNIST, self).__init__()
        # Creates self.linear0 ... self.linear7, in this order, so that layer tracking
        # and weight order stay stable.
        # NOTE(review): linear0 has 794 units rather than 784; the intent is not visible here.
        for idx, units in enumerate((794, 794, 256, 128, 64, 32, 16, 1)):
            setattr(self, f'linear{idx}', Dense(units))
        # Creates self.ln0 ... self.ln5, one per normalized stage.
        for idx in range(6):
            setattr(self, f'ln{idx}', LayerNormalization())
        # A single LeakyReLU instance with slope 0.2, shared by all stages.
        self.relu = LeakyReLU(0.2)

    def call(self, x):
        """
        Forward pass through the MNIST discriminator.

        Args:
            x (tf.Tensor): Input data; reshaped here to (-1, 784).

        Returns:
            tf.Tensor: Output of shape (batch, 1) from the final Dense layer.
        """
        h = tf.reshape(x, [-1, 784])
        normalized_stages = (
            (self.linear0, self.ln0),
            (self.linear1, self.ln1),
            (self.linear2, self.ln2),
            (self.linear3, self.ln3),
            (self.linear4, self.ln4),
            (self.linear5, self.ln5),
        )
        for dense, norm in normalized_stages:
            h = self.relu(norm(dense(h)))
        # Last hidden stage has no LayerNormalization.
        h = self.relu(self.linear6(h))
        return self.linear7(h)