model_jax module
- class model_jax.Discriminator(input_dim: int, spec_norm: bool, bounded: bool, layers_list: list, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Discriminator class responsible for initializing and processing the discriminator network.
- Parameters:
input_dim (int) – Input dimension size.
spec_norm (bool) – Whether to apply spectral normalization to layers.
bounded (bool) – Whether to apply a bounded activation function to the final output.
layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.
- Returns:
Output after applying the discriminator network with an optional bounded activation.
- Return type:
jax.numpy.DeviceArray
- bounded: bool
- bounded_activation()
Apply bounded activation using the tanh function to constrain outputs within [-M, M].
- Parameters:
x (jax.numpy.DeviceArray) – Input data.
- Returns:
Bounded output after applying tanh activation.
- Return type:
jax.numpy.DeviceArray
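One common way to constrain values to [-M, M] with tanh is sketched below. The function name `bounded_tanh`, the default value of `M`, and the exact scaling `M * tanh(x / M)` are assumptions for illustration; the documentation above does not state how M is defined or applied.

```python
import jax.numpy as jnp


def bounded_tanh(x, M=5.0):
    """Squash x into [-M, M]; hypothetical form, M is an assumed constant."""
    # tanh maps to [-1, 1]; dividing by M keeps the slope near 1 around zero,
    # and multiplying by M stretches the range to [-M, M].
    return M * jnp.tanh(x / M)


x = jnp.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = bounded_tanh(x, M=5.0)
```

With this scaling, small inputs pass through almost unchanged while large inputs saturate at the bound.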
- input_dim: int
- layers_list: list
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- spec_norm: bool
- class model_jax.DiscriminatorMNIST(parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Discriminator for the MNIST dataset, responsible for classifying real vs. fake images.
- Returns:
Output after applying the discriminator network.
- Return type:
jax.numpy.DeviceArray
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- class model_jax.Generator(X_dim: int, Z_dim: int, spec_norm: bool, layers_list: list, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: Module
Generator class responsible for initializing and processing the generator network.
- Parameters:
X_dim (int) – Dimension of the output generated by the generator.
Z_dim (int) – Dimension of the input latent space.
spec_norm (bool) – Whether to apply spectral normalization to layers.
layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.
- Returns:
Generated output after passing through the generator network.
- Return type:
jax.numpy.DeviceArray
- X_dim: int
- Z_dim: int
- layers_list: list
- name: str | None = None
- parent: Module | Scope | _Sentinel | None = None
- scope: Scope | None = None
- spec_norm: bool