model_jax module

class model_jax.Discriminator(input_dim: int, spec_norm: bool, bounded: bool, layers_list: list, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Discriminator module: defines the discriminator network's layers and applies them to the input data.

Parameters:
  • input_dim (int) – Dimension of the input data.

  • spec_norm (bool) – Whether to apply spectral normalization to layers.

  • bounded (bool) – Whether to apply a bounded activation function to the final output.

  • layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.

Returns:

Output of the discriminator network, passed through the bounded activation when bounded is True.

Return type:

jax.numpy.DeviceArray

bounded: bool
bounded_activation(x)

Apply a tanh-based bounded activation that constrains outputs to the interval [-M, M], where M is a bound defined in the implementation and is not exposed as a constructor parameter.

Parameters:

x (jax.numpy.DeviceArray) – Input data.

Returns:

Bounded output after applying tanh activation.

Return type:

jax.numpy.DeviceArray
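Read literally, the description corresponds to a scaled tanh, M · tanh(x). A minimal pure-JAX sketch under that assumption follows; passing M as a keyword argument is a hypothetical convenience, since the documented method takes only x.

```python
import jax
import jax.numpy as jnp


def bounded_activation(x: jnp.ndarray, M: float = 1.0) -> jnp.ndarray:
    """Scaled tanh (assumed form): maps any real input into [-M, M]."""
    return M * jnp.tanh(x)


raw = jnp.array([-100.0, -1.0, 0.0, 1.0, 100.0])
out = bounded_activation(raw, M=5.0)  # saturates near -5 and 5 at the extremes

# The slope M * (1 - tanh(x)**2) equals M at x = 0 and vanishes as |x| grows.
slope = jax.vmap(jax.grad(lambda v: bounded_activation(v, M=5.0)))(raw)
```

Because tanh saturates, very large raw scores map to values close to ±M and their gradients shrink toward zero.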

input_dim: int
layers_list: list
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
spec_norm: bool
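Spectral normalization divides a weight matrix by its largest singular value, so the normalized linear map has an L2 Lipschitz constant of about 1. This page does not say how spec_norm is implemented (Flax provides nn.SpectralNorm, but the module may do something else), so the following plain-JAX sketch only illustrates the underlying power-iteration estimate; spectral_norm_estimate and spectrally_normalize are hypothetical names.

```python
import jax
import jax.numpy as jnp


def spectral_norm_estimate(W: jnp.ndarray, n_iters: int = 20, seed: int = 0) -> jnp.ndarray:
    """Power-iteration estimate of the largest singular value of a 2-D matrix W."""
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")
    u = jax.random.normal(jax.random.PRNGKey(seed), (W.shape[1],))
    u = u / jnp.linalg.norm(u)
    for _ in range(n_iters):
        v = W @ u                    # update the left singular direction
        v = v / jnp.linalg.norm(v)
        u = W.T @ v                  # update the right singular direction
        u = u / jnp.linalg.norm(u)
    return v @ W @ u                 # bilinear estimate v^T W u of sigma_max


def spectrally_normalize(W: jnp.ndarray, n_iters: int = 20) -> jnp.ndarray:
    """Rescale W so its largest singular value is approximately 1."""
    return W / spectral_norm_estimate(W, n_iters)


W = jax.random.normal(jax.random.PRNGKey(42), (8, 5))
sigma = spectral_norm_estimate(W, n_iters=100)
exact = jnp.linalg.svd(W, compute_uv=False)[0]  # reference value from a full SVD
W_sn = spectrally_normalize(W, n_iters=100)
```

In GAN training, implementations usually keep u between steps and run only one iteration per update, which is much cheaper than recomputing an SVD.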
class model_jax.DiscriminatorMNIST(parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Discriminator for the MNIST dataset that classifies images as real or fake.

Returns:

Output after applying the discriminator network.

Return type:

jax.numpy.DeviceArray

name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
class model_jax.Generator(X_dim: int, Z_dim: int, spec_norm: bool, layers_list: list, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Generator module: defines the generator network's layers and maps latent vectors of dimension Z_dim to outputs of dimension X_dim.

Parameters:
  • X_dim (int) – Dimension of the output generated by the generator.

  • Z_dim (int) – Dimension of the input latent space.

  • spec_norm (bool) – Whether to apply spectral normalization to layers.

  • layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.

Returns:

Generated output of dimension X_dim.

Return type:

jax.numpy.DeviceArray

X_dim: int
Z_dim: int
layers_list: list
name: str | None = None
parent: Module | Scope | _Sentinel | None = None
scope: Scope | None = None
spec_norm: bool