Welcome to NeuroDiME’s documentation!

class Divergences_jax.DataLoader(data, batch_size, shuffle=True)

DataLoader class for loading and batching data during training.
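Below is a minimal sketch of the batching behaviour that the constructor arguments suggest. It assumes the loader yields consecutive mini-batches of at most batch_size samples and draws a new sample order on each pass when shuffle=True. `SimpleLoader` and its `seed` argument are hypothetical. Whether the library drops or keeps a final partial batch is not documented here.

```python
import random


class SimpleLoader:
    """Hypothetical mini-batch loader; a stand-in, not the library's DataLoader."""

    def __init__(self, data, batch_size, shuffle=True, seed=0):
        self.data = list(data)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)  # seeded so the example is reproducible

    def __iter__(self):
        order = list(range(len(self.data)))
        if self.shuffle:
            self.rng.shuffle(order)  # a fresh sample order on every pass
        for start in range(0, len(order), self.batch_size):
            # This sketch keeps the final, possibly smaller, batch.
            yield [self.data[i] for i in order[start:start + self.batch_size]]


loader = SimpleLoader(range(10), batch_size=4)
batches = list(loader)
print([len(b) for b in batches])  # [4, 4, 2]
```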

class Divergences_jax.Discriminator_Penalty(penalty_weight)

Base class for penalties on the discriminator during training. Adding a penalty term to the divergence objective allows discriminator constraints (such as a Lipschitz bound) to be enforced approximately.

evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)

Evaluates the penalty term. Should be overridden by subclasses.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • batch_stats – Batch statistics of the discriminator (e.g., batch-normalization running averages).

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

None. (Subclasses should implement specific penalty evaluations.)

get_penalty_weight()

Returns the weight of the penalty.

Returns:

Penalty weight.

set_penalty_weight(weight)

Sets the weight of the penalty.

Parameters:

weight – New penalty weight.

class Divergences_jax.Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Base class for divergence measures D(P||Q) between random variables x~P, y~Q. It defines the parameters and methods shared by the different divergence estimators.

discriminate(x, params, vars, labels=None, dropout_rng=None)

Applies the discriminator model to the samples x (drawn from P or from Q).

Parameters:
  • x – Input data to be discriminated.

  • params – Parameters of the discriminator model.

  • vars – Additional variables such as batch statistics.

  • labels – Optional labels for the input data.

  • dropout_rng – Optional dropout key for stochasticity in dropout layers.

Returns:

Tuple of the discriminator output and optional updated batch statistics.

discriminator_loss(x, y, params, vars, labels=None, dropout_rng=None)

Computes the loss for the discriminator model.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of discriminator loss and updated variables.

estimate(x, y, params, vars, labels=None, dropout_rng=None)

Estimates the divergence between P and Q.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of estimated divergence and updated variables.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Placeholder method for evaluating the variational formula of a specific divergence. Should be overridden by subclasses.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

None.

gen_train_step(gen_state, disc_state, disc_vars, gen_vars, key, z, labels=None, dropout_rng=None)

Performs a single training step for the generator.

Parameters:
  • gen_state – Generator optimizer state.

  • disc_state – Discriminator optimizer state.

  • disc_vars – Discriminator variables.

  • gen_vars – Generator variables.

  • key – Random key for JAX RNG.

  • z – Latent input to the generator.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Updated generator state and generator loss.

generator_loss(x, params, vars, labels=None, dropout_rng=None)

Computes the loss for the generator model.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

get_batch_size()

Returns the batch size.

get_discriminator()

Returns the discriminator model.

get_learning_rate()

Returns the learning rate.

get_no_epochs()

Returns the number of training epochs.

set_batch_size(batch_size)

Sets the batch size.

Parameters:

batch_size – New batch size.

set_discriminator(discriminator)

Sets a new discriminator model.

Parameters:

discriminator – New discriminator model.

set_learning_rate(lr)

Sets the learning rate.

Parameters:

lr – New learning rate.

set_no_epochs(epochs)

Sets the number of training epochs.

Parameters:

epochs – New number of epochs.

train(data_P, data_Q, state, vars, save_estimates=True, labels=None, dropout_rng=None)

Trains the discriminator for the configured number of epochs.

Parameters:
  • data_P – Data samples from distribution P.

  • data_Q – Data samples from distribution Q.

  • state – Discriminator optimizer state.

  • vars – Discriminator variables.

  • save_estimates – Whether to save divergence estimates.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of estimated divergences and losses for each epoch.

train_step(x, y, state, vars, key, labels=None, dropout_rng=None)

Performs a single training step for the discriminator.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • state – Optimizer state.

  • vars – Additional discriminator variables.

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Updated state and loss value for the current step.
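The following toy sketch shows how train and train_step fit together. It is framework-free, and all names are hypothetical rather than the library's API. It assumes a one-parameter discriminator g(x) = a*x, the Donsker-Varadhan objective used by KLD_DV, and one gradient-ascent step per mini-batch. The library instead works with JAX models, parameters, and optimizer states. For P = N(1, 1) and Q = N(0, 1), log(dP/dQ)(x) = x - 1/2. The optimal slope is therefore a = 1, and the objective at that slope equals KL(P||Q) = 1/2.

```python
import math
import random


def dv_objective(a, xs_p, xs_q):
    """Donsker-Varadhan objective E_P[g] - log E_Q[exp(g)] for g(x) = a*x."""
    vals_q = [a * y for y in xs_q]
    m = max(vals_q)  # shift for a numerically stable log-mean-exp
    log_mean_exp = m + math.log(sum(math.exp(v - m) for v in vals_q) / len(vals_q))
    return a * sum(xs_p) / len(xs_p) - log_mean_exp


def dv_grad(a, xs_p, xs_q):
    """d/da of dv_objective: mean of x under P minus the exp(g)-weighted mean under Q."""
    vals_q = [a * y for y in xs_q]
    m = max(vals_q)
    weights = [math.exp(v - m) for v in vals_q]  # the shift m cancels in the ratio
    tilted_mean = sum(w * y for w, y in zip(weights, xs_q)) / sum(weights)
    return sum(xs_p) / len(xs_p) - tilted_mean


def train(xs_p, xs_q, epochs=5, batch_size=500, lr=0.1, seed=0):
    """Toy analogue of train(): epochs of mini-batch gradient ascent on a."""
    rng = random.Random(seed)
    xs_p, xs_q = list(xs_p), list(xs_q)  # copies, so shuffling leaves inputs intact
    a = 0.0  # the single discriminator parameter
    for _ in range(epochs):
        rng.shuffle(xs_p)
        rng.shuffle(xs_q)
        for start in range(0, len(xs_p), batch_size):
            batch_p = xs_p[start:start + batch_size]
            batch_q = xs_q[start:start + batch_size]
            a += lr * dv_grad(a, batch_p, batch_q)  # one "train_step"
    return a, dv_objective(a, xs_p, xs_q)


rng = random.Random(1)
xs_p = [rng.gauss(1.0, 1.0) for _ in range(20_000)]  # P = N(1, 1)
xs_q = [rng.gauss(0.0, 1.0) for _ in range(20_000)]  # Q = N(0, 1)
a, estimate = train(xs_p, xs_q)
print(a, estimate)  # a should approach 1 and the estimate KL(P||Q) = 0.5
```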

class Divergences_jax.Gradient_Penalty_1Sided(penalty_weight, Lip_const)

One-sided gradient penalty that (approximately) enforces the constraint that the discriminator's Lipschitz constant is at most Lip_const.

evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)

Computes the one-sided gradient penalty to enforce the Lipschitz constant constraint.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • batch_stats – Batch statistics of the discriminator (e.g., batch-normalization running averages).

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

One-sided gradient penalty value.

get_Lip_constant()

Returns the target Lipschitz constant.

Returns:

Lipschitz constant.

set_Lip_constant(L)

Sets the target Lipschitz constant.

Parameters:

L – New Lipschitz constant.

class Divergences_jax.Gradient_Penalty_2Sided(penalty_weight, Lip_const)

Two-sided gradient penalty that (approximately) enforces the constraint that the discriminator's Lipschitz constant equals Lip_const.

evaluate(discriminator, x, y, params, labels=None, dropout_rng=None)

Computes the two-sided gradient penalty to enforce the Lipschitz constant constraint.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Two-sided gradient penalty value.

get_Lip_constant()

Returns the target Lipschitz constant.

Returns:

Lipschitz constant.

set_Lip_constant(L)

Sets the target Lipschitz constant.

Parameters:

L – New Lipschitz constant.
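The penalty formulas are not spelled out above. A common form comes from gradient-penalty training of Wasserstein GANs. It samples points z = t*x + (1 - t)*y between paired samples and penalizes the discriminator's gradient norm at those points:

- One-sided version: weight * mean(max(|grad g(z)| - Lip_const, 0)^2).
- Two-sided version: weight * mean((|grad g(z)| - Lip_const)^2).

The sketch below assumes these forms for a 1-D discriminator and approximates the derivative by central differences. The library may normalize by Lip_const or sample the interpolates differently. All function names are hypothetical.

```python
import random


def interpolates(xs, ys, rng):
    """Random points t*x + (1 - t)*y on the segments between paired P and Q samples."""
    points = []
    for x, y in zip(xs, ys):
        t = rng.random()  # t ~ Uniform(0, 1)
        points.append(t * x + (1.0 - t) * y)
    return points


def grad_norms(g, zs, h=1e-5):
    """|g'(z)| by central differences: a 1-D stand-in for the gradient norm."""
    return [abs(g(z + h) - g(z - h)) / (2.0 * h) for z in zs]


def gradient_penalty(g, xs, ys, weight, lip_const, one_sided=True, seed=0):
    norms = grad_norms(g, interpolates(xs, ys, random.Random(seed)))
    if one_sided:
        terms = [max(n - lip_const, 0.0) ** 2 for n in norms]  # only slopes above L count
    else:
        terms = [(n - lip_const) ** 2 for n in norms]  # slopes are pulled toward L
    return weight * sum(terms) / len(terms)


def toy_discriminator(z):
    return 3.0 * z  # slope 3 everywhere, so its Lipschitz constant is 3


xs, ys = [0.0, 1.0, 2.0], [1.0, 2.0, 3.0]
p_one_low = gradient_penalty(toy_discriminator, xs, ys, weight=10.0, lip_const=1.0)
p_one_high = gradient_penalty(toy_discriminator, xs, ys, weight=10.0, lip_const=5.0)
p_two_high = gradient_penalty(toy_discriminator, xs, ys, weight=10.0, lip_const=5.0,
                              one_sided=False)
print(p_one_low, p_one_high, p_two_high)  # about 40, exactly 0.0, about 40
```

The one-sided form ignores discriminators that already satisfy the bound. The two-sided form also penalizes slopes below Lip_const.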

class Divergences_jax.IPM(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Integral Probability Metric (IPM) class, a subclass of Divergence. Evaluates the IPM between distributions P and Q via its variational formula.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for IPM.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for IPM when applied to a generator model.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
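For a fixed discriminator g, the IPM variational formula is the difference of expectations E_P[g] - E_Q[g]. The IPM is its supremum over a function class; over 1-Lipschitz functions, for example, it is the Wasserstein-1 distance by Kantorovich-Rubinstein duality. Below is a minimal sample-based sketch, assuming eval_var_formula evaluates this difference; the helper names are hypothetical. The second call shows why the class of g must be constrained (e.g., by a gradient penalty). Scaling g scales the objective, so over unconstrained g the supremum is infinite whenever P differs from Q.

```python
def ipm_objective(g, xs_p, xs_q):
    """Sample estimate of E_P[g] - E_Q[g] for a fixed discriminator g."""
    return sum(g(x) for x in xs_p) / len(xs_p) - sum(g(y) for y in xs_q) / len(xs_q)


def identity(z):
    return z  # 1-Lipschitz


def scaled(z):
    return 10.0 * z  # 10-Lipschitz


xs_p = [1.0, 2.0, 3.0]  # sample mean 2.0
xs_q = [0.0, 0.5, 1.0]  # sample mean 0.5
print(ipm_objective(identity, xs_p, xs_q))  # 1.5
print(ipm_objective(scaled, xs_p, xs_q))    # 15.0: scaling g scales the objective
```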

class Divergences_jax.Jensen_Shannon_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = y * log(y) - (y + 1) * log((y + 1) / 2).

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

final_layer_activation(y)

Final layer activation function for Jensen-Shannon divergence.

Parameters:

y – Input to the activation function.

Returns:

Activated output.
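Setting the derivative f'(t) = log(2t/(t + 1)) equal to y gives the maximizer t = e^y/(2 - e^y). This yields the closed form f*(y) = -log(2 - e^y), which is finite only for y < log 2. That is presumably why the class provides final_layer_activation: the discriminator output has to stay below log 2. The sketch checks the closed form against a brute-force supremum. It also uses log 2 - softplus(v) as one activation with the right range; that activation is an assumption, not necessarily the library's choice. Helper names are hypothetical.

```python
import math


def f_js(t):
    """Jensen-Shannon generator f(t) = t*log(t) - (t + 1)*log((t + 1)/2), t > 0."""
    return t * math.log(t) - (t + 1.0) * math.log((t + 1.0) / 2.0)


def f_star_js(y):
    """Closed-form Legendre transform -log(2 - e^y); finite only for y < log 2."""
    return -math.log(2.0 - math.exp(y))


def softplus(v):
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))  # overflow-safe


def js_activation(v):
    """Maps any real v into (-inf, log 2) (an assumed choice of activation)."""
    return math.log(2.0) - softplus(v)


grid = [k / 1000 for k in range(1, 20001)]  # t in (0, 20]
for y in (-1.0, 0.0, 0.5):
    numeric = max(t * y - f_js(t) for t in grid)  # brute-force supremum
    print(y, f_star_js(y), numeric)
print([js_activation(v) < math.log(2.0) for v in (-5.0, 0.0, 5.0)])  # [True, True, True]
```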

class Divergences_jax.KLD_DV(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

KL Divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for KL divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of KL divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for KL divergence applied to a generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
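The Donsker-Varadhan formula reads KL(P||Q) = sup_g { E_P[g] - log E_Q[exp(g)] }. Below is a sample-based sketch of the objective for a fixed g, assuming eval_var_formula evaluates this expression. The library may return it negated as a loss, and the helper names are hypothetical. The sketch uses a max-shifted log-mean-exp for numerical stability and illustrates two properties:

- Adding a constant to g leaves the objective unchanged.
- g = log(dP/dQ) attains the divergence.

For P = N(1, 1) and Q = N(0, 1), log(dP/dQ)(x) = x - 1/2 and KL(P||Q) = 1/2.

```python
import math
import random


def log_mean_exp(values):
    m = max(values)  # subtract the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def dv_objective(g, xs_p, xs_q):
    """Sample estimate of E_P[g] - log E_Q[exp(g)]."""
    return sum(g(x) for x in xs_p) / len(xs_p) - log_mean_exp([g(y) for y in xs_q])


def log_ratio(x):
    return x - 0.5  # log dP/dQ for P = N(1, 1), Q = N(0, 1)


def shifted(x):
    return log_ratio(x) + 3.0  # the same discriminator plus a constant


rng = random.Random(0)
xs_p = [rng.gauss(1.0, 1.0) for _ in range(100_000)]
xs_q = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
est = dv_objective(log_ratio, xs_p, xs_q)
est_shifted = dv_objective(shifted, xs_p, xs_q)
print(est, est_shifted)  # identical up to rounding, both near KL(P||Q) = 0.5
```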

class Divergences_jax.KLD_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Kullback-Leibler (KL) Divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = y * log(y).

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.
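Setting f'(t) = log(t) + 1 equal to y gives t = e^(y-1), so f*(y) = e^(y-1). The Legendre-transform formula then becomes KL(P||Q) = sup_g { E_P[g] - E_Q[e^(g-1)] }, attained at g = 1 + log(dP/dQ). Because log z <= z/e for every z > 0, this objective never exceeds the Donsker-Varadhan objective of KLD_DV for the same g. The sketch below uses exact expectations over two-point distributions in place of sample means; helper names are hypothetical.

```python
import math


def expect(weights, values):
    """Exact expectation over a finite distribution."""
    return sum(w * v for w, v in zip(weights, values))


def lt_objective(p, q, g):
    """E_P[g] - E_Q[f*(g)] with the KL Legendre transform f*(y) = exp(y - 1)."""
    return expect(p, g) - expect(q, [math.exp(v - 1.0) for v in g])


def dv_objective(p, q, g):
    """Donsker-Varadhan objective E_P[g] - log E_Q[exp(g)], for comparison."""
    return expect(p, g) - math.log(expect(q, [math.exp(v) for v in g]))


p = [0.5, 0.5]    # P on two points
q = [0.25, 0.75]  # Q on the same two points
kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

g_opt = [1.0 + math.log(pi / qi) for pi, qi in zip(p, q)]  # 1 + log dP/dQ
g_any = [0.3, -0.2]
print(kl, lt_objective(p, q, g_opt))                             # equal up to rounding
print(lt_objective(p, q, g_any), dv_objective(p, q, g_any), kl)  # increasing order
```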

class Divergences_jax.Pearson_chi_squared_HCR(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional labels for the input data.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Pearson chi-squared divergence loss.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the generator’s objective based on Pearson chi-squared divergence.
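The Hammersley-Chapman-Robbins bound states that chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g] for any g with Var_Q[g] > 0. By the Cauchy-Schwarz inequality, equality holds when g is an affine function of dP/dQ. Below is a minimal sketch that evaluates the bound on samples representing two discrete distributions exactly. The helper is hypothetical, and how the library arranges the expression as a training loss is not documented here.

```python
import statistics


def hcr_bound(g, xs_p, xs_q):
    """(E_P[g] - E_Q[g])^2 / Var_Q[g]: a lower bound on chi^2(P||Q) for non-constant g."""
    g_p = [g(x) for x in xs_p]
    g_q = [g(y) for y in xs_q]
    diff = statistics.fmean(g_p) - statistics.fmean(g_q)
    return diff ** 2 / statistics.pvariance(g_q)  # population variance under Q


xs_p = [0, 1, 2]     # P: mass 1/3 on each point
xs_q = [0, 0, 1, 2]  # Q: mass 1/2 on 0, 1/4 on 1 and on 2
ratio = {0: 2 / 3, 1: 4 / 3, 2: 4 / 3}  # dP/dQ; chi^2(P||Q) = E_Q[(ratio - 1)^2] = 1/9

print(hcr_bound(lambda x: ratio[x], xs_p, xs_q))  # 1/9: tight, g is the density ratio
print(hcr_bound(lambda x: float(x), xs_p, xs_q))  # 1/11: strictly below chi^2
```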

class Divergences_jax.Pearson_chi_squared_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = (y - 1)^2.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.
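The conjugate of f(t) = (t - 1)^2 depends on the domain of t. Over all real t, it is f*(y) = y + y^2/4. Density ratios are non-negative, so t may instead be restricted to t >= 0. In that case the interior maximizer t = 1 + y/2 is infeasible for y < -2, the supremum moves to t = 0, and f*(y) = -1 there. The documentation above does not state which convention f_star uses. The sketch implements the t >= 0 version as a hypothetical helper and checks it against a brute-force supremum.

```python
def f_star_pearson(y):
    """Legendre transform of (t - 1)^2 restricted to t >= 0."""
    if y >= -2.0:
        return y + y * y / 4.0  # interior maximizer t = 1 + y/2
    return -1.0  # boundary maximizer t = 0


def f_star_numeric(y, ts):
    """Brute-force sup over a grid of t values."""
    return max(t * y - (t - 1.0) ** 2 for t in ts)


grid = [k / 1000 for k in range(0, 10001)]  # t in [0, 10]
checks = {y: (f_star_pearson(y), f_star_numeric(y, grid)) for y in (-4.0, -2.0, 0.0, 3.0)}
for y, (closed, numeric) in checks.items():
    print(y, closed, numeric)
```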

class Divergences_jax.Renyi_Divergence(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)

Renyi divergence class, a subclass of Divergence. R_alpha(P||Q), x~P, y~Q.

get_order()

Returns the order of the Renyi divergence.

Returns:

Alpha order.

set_order(alpha)

Sets the order of the Renyi divergence.

Parameters:

alpha – New alpha order.
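The documentation does not state the normalization of R_alpha. The sketch below assumes R_alpha(P||Q) = log(E_Q[(dP/dQ)^alpha]) / (alpha*(alpha - 1)). Under this scaling, R_alpha tends to KL(P||Q) as alpha -> 1, and alpha*R_alpha tends to the worst-case regret divergence as alpha -> infinity. The sketch computes R_alpha exactly for two discrete distributions; helper names are hypothetical.

```python
import math


def renyi(p, q, alpha):
    """R_alpha(P||Q) = log(E_Q[(dP/dQ)^alpha]) / (alpha*(alpha - 1)), alpha not in {0, 1}."""
    moment = sum(qi * (pi / qi) ** alpha for pi, qi in zip(p, q))
    return math.log(moment) / (alpha * (alpha - 1.0))


def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


p = [0.5, 0.5]
q = [0.25, 0.75]
for alpha in (0.5, 2.0, 5.0):
    print(alpha, renyi(p, q, alpha))
print(renyi(p, q, 1.001), kl(p, q))  # nearly equal: R_alpha -> KL as alpha -> 1
```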

class Divergences_jax.Renyi_Divergence_CC(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)

Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of Renyi divergence using the convex-conjugate variational formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of Renyi divergence for the generator using the convex-conjugate variational formula.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

final_layer_activation(y)

Final layer activation function to enforce positive values.

Parameters:

y – Output of the discriminator.

Returns:

Activated output, ensuring positivity based on the final activation function.

class Divergences_jax.Renyi_Divergence_CC_rescaled(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)

Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of the rescaled Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of rescaled Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for the generator of the rescaled Renyi divergence.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

final_layer_activation(y)

Final layer activation function to enforce positivity, scaled by alpha.

Parameters:

y – Output of the discriminator.

Returns:

Activated output, scaled by the alpha parameter.

class Divergences_jax.Renyi_Divergence_DV(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)

Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of Renyi divergence for the generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
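This sketch assumes the normalization R_alpha(P||Q) = log(E_Q[(dP/dQ)^alpha]) / (alpha*(alpha - 1)). Under it, a Renyi-Donsker-Varadhan formula reads:

R_alpha(P||Q) = sup_g { 1/(alpha - 1) log E_P[exp((alpha - 1) g)] - 1/alpha log E_Q[exp(alpha g)] }

This reduces to the Donsker-Varadhan formula as alpha -> 1 and is attained at g = log(dP/dQ). The sketch evaluates it by Monte Carlo for P = N(1, 1) and Q = N(0, 1), where R_alpha(P||Q) = 1/2 for every alpha under this normalization. Both the normalization and the exact form used by eval_var_formula are assumptions, and helper names are hypothetical.

```python
import math
import random


def log_mean_exp(values):
    m = max(values)  # shift by the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def renyi_dv_objective(g, xs_p, xs_q, alpha):
    """1/(alpha-1) log E_P[e^{(alpha-1) g}] - 1/alpha log E_Q[e^{alpha g}]."""
    term_p = log_mean_exp([(alpha - 1.0) * g(x) for x in xs_p]) / (alpha - 1.0)
    term_q = log_mean_exp([alpha * g(y) for y in xs_q]) / alpha
    return term_p - term_q


def log_ratio(x):
    return x - 0.5  # log dP/dQ for P = N(1, 1), Q = N(0, 1)


rng = random.Random(0)
xs_p = [rng.gauss(1.0, 1.0) for _ in range(200_000)]
xs_q = [rng.gauss(0.0, 1.0) for _ in range(200_000)]
estimate = renyi_dv_objective(log_ratio, xs_p, xs_q, alpha=2.0)
print(estimate)  # Monte Carlo estimate of R_2(P||Q) = 0.5
```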

class Divergences_jax.Renyi_Divergence_WCR(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)

Rescaled Renyi divergence class as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of the Renyi divergence class as alpha approaches infinity (worst-case regret divergence).

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of worst-case regret divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for the generator of the worst-case regret divergence.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
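The worst-case regret divergence is D_infty(P||Q) = log(ess sup dP/dQ). This sketch assumes alpha*R_alpha(P||Q) = log(E_Q[(dP/dQ)^alpha]) / (alpha - 1), which is the rescaled quantity targeted by Renyi_Divergence_CC_rescaled under a 1/(alpha*(alpha - 1)) normalization of R_alpha. It computes the rescaled divergence in log space for two discrete distributions and shows it increasing toward D_infty as alpha grows. Helper names are hypothetical.

```python
import math


def log_sum_exp(values):
    m = max(values)  # shift by the max so exp() cannot overflow
    return m + math.log(sum(math.exp(v - m) for v in values))


def rescaled_renyi(p, q, alpha):
    """alpha * R_alpha(P||Q) = log(E_Q[(dP/dQ)^alpha]) / (alpha - 1), in log space."""
    logs = [math.log(qi) + alpha * math.log(pi / qi) for pi, qi in zip(p, q)]
    return log_sum_exp(logs) / (alpha - 1.0)


def worst_case_regret(p, q):
    """D_infty(P||Q) = log max dP/dQ for distributions on a finite set."""
    return math.log(max(pi / qi for pi, qi in zip(p, q)))


p = [0.5, 0.5]
q = [0.25, 0.75]
values = [rescaled_renyi(p, q, alpha) for alpha in (2.0, 10.0, 100.0, 1000.0)]
print(values)                   # increasing in alpha
print(worst_case_regret(p, q))  # log 2, the limit of the values above
```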

class Divergences_jax.alpha_Divergence_LT(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)

Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f_alpha based on the alpha value.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

get_order()

Returns the order of the alpha-divergence.

Returns:

Alpha order.

set_order(alpha)

Sets the order of the alpha-divergence.

Parameters:

alpha – New alpha order.

class Divergences_jax.f_Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses need to implement the Legendre transform of f (f_star).

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula of f-divergence, D_f(P||Q).

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)

Evaluates the variational formula for f-divergence when applied to a generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

f_star(y)

Placeholder for the Legendre transform of the function f. Should be implemented by subclasses.

Parameters:

y – Input to the Legendre transform.

Returns:

None.

final_layer_activation(y)

Final activation function applied to the output of the discriminator.

Parameters:

y – Output of the discriminator.

Returns:

Activated output.
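The class structure above can be mirrored in a few lines. This hypothetical, framework-free sketch is not the library's code. It assumes the Legendre-transform bound D_f(P||Q) >= E_P[g] - E_Q[f*(g)], with final_layer_activation applied to the discriminator output first. A Pearson chi-squared subclass uses the conjugate y + y^2/4 of (t - 1)^2 over all real t. It is checked on two discrete distributions written as sample lists, where the optimal discriminator is g = f'(dP/dQ) = 2*(dP/dQ - 1) and chi^2(P||Q) = 1/3.

```python
class FDivergenceSketch:
    """Hypothetical mirror of the f_Divergence pattern (not the library class)."""

    def f_star(self, y):
        raise NotImplementedError("subclasses supply the Legendre transform")

    def final_layer_activation(self, y):
        return y  # identity unless f_star has a restricted domain

    def eval_var_formula(self, g, xs_p, xs_q):
        """Sample estimate of E_P[g] - E_Q[f*(g)]."""
        g_p = [self.final_layer_activation(g(x)) for x in xs_p]
        g_q = [self.final_layer_activation(g(y)) for y in xs_q]
        return sum(g_p) / len(g_p) - sum(self.f_star(v) for v in g_q) / len(g_q)


class PearsonSketch(FDivergenceSketch):
    def f_star(self, y):
        return y + y * y / 4.0  # conjugate of (t - 1)^2 over all real t


xs_p = [0, 1]        # P: mass 1/2 on each point
xs_q = [0, 1, 1, 1]  # Q: mass 1/4 on 0 and 3/4 on 1
ratio = {0: 2.0, 1: 2.0 / 3.0}  # dP/dQ at each point

chi2 = PearsonSketch().eval_var_formula(lambda x: 2.0 * (ratio[x] - 1.0), xs_p, xs_q)
print(chi2)  # 1/3 up to rounding
```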

class Divergences_jax.squared_Hellinger_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)

Squared Hellinger distance class based on the Legendre transform. H(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = (sqrt(y) - 1)^2.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

final_layer_activation(y)

Final layer activation for squared Hellinger distance.

Parameters:

y – Input to the activation function.

Returns:

Activated output.
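Setting f'(t) = 1 - 1/sqrt(t) equal to y gives sqrt(t) = 1/(1 - y), hence f*(y) = y/(1 - y). This is finite only for y < 1, so the final-layer activation has to keep the discriminator output below 1. The sketch uses 1 - exp(-v) as one such map; this is an assumption, and the library may use a different function. It then checks a two-point example:

- Raw outputs v = log(dP/dQ)/2 yield the optimal g = 1 - (dP/dQ)^(-1/2).
- For that g, E_P[g] - E_Q[f*(g)] equals the squared Hellinger distance E_Q[(sqrt(dP/dQ) - 1)^2].

Helper names are hypothetical.

```python
import math


def f_star_hellinger(y):
    """Legendre transform of (sqrt(t) - 1)^2; finite only for y < 1."""
    if y >= 1.0:
        raise ValueError("f* is infinite for y >= 1")
    return y / (1.0 - y)


def hellinger_activation(v):
    """Maps any real v into (-inf, 1) (an assumed choice of activation)."""
    return 1.0 - math.exp(-v)


p = [0.5, 0.5]
q = [0.25, 0.75]
ratios = [pi / qi for pi, qi in zip(p, q)]
h2 = sum(qi * (math.sqrt(r) - 1.0) ** 2 for qi, r in zip(q, ratios))

# Raw outputs v = log(dP/dQ)/2 become the optimal g = 1 - (dP/dQ)**(-1/2).
g = [hellinger_activation(0.5 * math.log(r)) for r in ratios]
estimate = sum(pi * gi for pi, gi in zip(p, g)) - sum(
    qi * f_star_hellinger(gi) for qi, gi in zip(q, g)
)
print(h2, estimate)  # equal up to rounding
```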

class Divergences_tf.Discriminator_Penalty(penalty_weight)

Base class for discriminator penalties, which add a penalty term to the divergence objective functional during training. Allows discriminator constraints to be implemented approximately.

evaluate(discriminator, x, y, labels=None)

Evaluates the penalty term. Should be overridden by subclasses.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None. (Subclasses should implement penalty evaluation.)

get_penalty_weight()

Returns the weight of the penalty term.

set_penalty_weight(weight)

Sets the weight of the penalty term.

class Divergences_tf.Divergence(*args, **kwargs)

Base class for divergence measures D(P||Q) between random variables x~P, y~Q. It defines the parameters and methods shared by the different divergence estimators.

discriminate(x, labels=None)

Applies the discriminator model to the samples x (drawn from P or from Q).

Parameters:
  • x – Input data to discriminate.

  • labels – Optional labels for the input data.

Returns:

Discriminator output.

discriminator_loss(x, y, labels=None)

Computes the discriminator loss.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Discriminator loss.

estimate(x, y, labels=None)

Estimates the divergence measure.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Estimated divergence loss.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for the divergence measure. Should be implemented by subclasses.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None (to be overridden by subclasses).

generator_loss(x, labels=None)

Computes the generator loss.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss.

get_batch_size()

Returns the batch size.

get_discriminator()

Returns the discriminator model.

get_learning_rate()

Returns the learning rate.

get_no_epochs()

Returns the number of training epochs.

set_batch_size(BATCH_SIZE)

Sets the batch size.

set_discriminator(discriminator)

Sets a new discriminator model.

set_learning_rate(lr)

Sets the learning rate.

set_no_epochs(epochs)

Sets the number of training epochs.

train(data_P, data_Q, labels=None, save_estimates=True)

Trains the discriminator for the configured number of epochs.

Parameters:
  • data_P – Data samples from distribution P.

  • data_Q – Data samples from distribution Q.

  • labels – Optional labels for the input data.

  • save_estimates – Whether to save divergence estimates.

Returns:

A list of divergence estimates for each epoch.

train_step(x, y, labels=None)

Performs a training step for the discriminator.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Loss value for the current step.

class Divergences_tf.Gradient_Penalty_1Sided(penalty_weight, Lip_const)

One-sided gradient penalty class to enforce the Lipschitz constant <= Lip_const.

evaluate(discriminator, x, y, labels=None)

Computes the one-sided gradient penalty to enforce the Lipschitz constant <= Lip_const.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

One-sided gradient penalty value.

get_Lip_constant()

Returns the target Lipschitz constant.

set_Lip_constant(L)

Sets the target Lipschitz constant.

class Divergences_tf.Gradient_Penalty_2Sided(penalty_weight, Lip_const)

Two-sided gradient penalty class to enforce the Lipschitz constant = Lip_const.

evaluate(discriminator, x, y, labels=None)

Computes the two-sided gradient penalty to enforce the Lipschitz constant = Lip_const.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Two-sided gradient penalty value.

get_Lip_constant()

Returns the target Lipschitz constant.

set_Lip_constant(L)

Sets the target Lipschitz constant.

class Divergences_tf.IPM(*args, **kwargs)

Integral Probability Metric (IPM) class. Evaluates the IPM between distributions P and Q via its variational formula.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for IPM.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

IPM loss.

eval_var_formula_gen(x, labels=None)

Evaluates the variational formula for IPM applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on IPM.

class Divergences_tf.Jensen_Shannon_LT(*args, **kwargs)

Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = y * log(y) - (y + 1) * log((y + 1) / 2).

final_layer_activation(y)

Applies the final layer activation for Jensen-Shannon divergence.

class Divergences_tf.KLD_DV(*args, **kwargs)

KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for KL divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

KL divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the generator’s objective based on KL divergence.

class Divergences_tf.KLD_LT(*args, **kwargs)

Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = y * log(y).

class Divergences_tf.Pearson_chi_squared_HCR(*args, **kwargs)

Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Pearson chi-squared divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the generator’s objective based on Pearson chi-squared divergence.

class Divergences_tf.Pearson_chi_squared_LT(*args, **kwargs)

Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f(y) = (y - 1)^2.

class Divergences_tf.Renyi_Divergence(*args, **kwargs)

Renyi divergence class for computing Renyi divergence R_alpha(P||Q), x~P, y~Q.

get_order()

Returns the order of the Renyi divergence.

set_order(alpha)

Sets the order of the Renyi divergence.

class Divergences_tf.Renyi_Divergence_CC(*args, **kwargs)

Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for Renyi divergence based on the convex-conjugate formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the generator’s objective based on Renyi divergence using the convex-conjugate formula.

final_layer_activation(y)

Applies the final layer activation to enforce positive values.

class Divergences_tf.Renyi_Divergence_CC_rescaled(*args, **kwargs)

Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for the rescaled Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Rescaled Renyi divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the generator’s objective based on rescaled Renyi divergence.

final_layer_activation(y)

Applies the final layer activation and scales it by alpha.

class Divergences_tf.Renyi_Divergence_DV(*args, **kwargs)

Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for Renyi divergence based on the Renyi-Donsker-Varadhan formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the generator’s objective based on Renyi divergence.

class Divergences_tf.Renyi_Divergence_WCR(*args, **kwargs)

Renyi divergence class as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for worst-case regret divergence as alpha approaches infinity.

Parameters:
  • x (tf.Tensor) – Samples from distribution P.

  • y (tf.Tensor) – Samples from distribution Q.

  • labels (tf.Tensor, optional) – Optional labels for the input data.

Returns:

Worst-case regret divergence loss.

Return type:

tf.Tensor

eval_var_formula_gen(X, labels=None)

Evaluates the generator’s objective based on worst-case regret divergence.

Parameters:
  • X (tf.Tensor) – Samples from distribution Q.

  • labels (tf.Tensor, optional) – Optional labels for the input data.

Returns:

Generator’s loss based on worst-case regret divergence.

Return type:

tf.Tensor

class Divergences_tf.alpha_Divergence_LT(*args, **kwargs)

Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.

f_star(y)

Legendre transform of f_alpha based on the alpha value.

get_order()

Returns the order of the alpha-divergence.

set_order(alpha)

Sets the order of the alpha-divergence.

class Divergences_tf.f_Divergence(*args, **kwargs)

f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses must implement the Legendre transform f_star.

eval_var_formula(x, y, labels=None)

Evaluates the variational formula for f-divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

f-divergence loss.

eval_var_formula_gen(x, labels=None)

Evaluates the variational formula for f-divergence applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on f-divergence.

f_star(y)[source]

Placeholder for the Legendre transform of f. Should be implemented by subclasses.

final_layer_activation(y)[source]

Applies the final layer activation.
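The variational formula behind this class is D_f(P||Q) = sup_g {E_P[g] - E_Q[f^*(g)]}, so any discriminator g yields a lower bound. The minimal sketch below evaluates that objective on precomputed discriminator outputs. For illustration it plugs in the KL choice f(t) = t log t, whose transform is f^*(y) = exp(y - 1). Function names are hypothetical, and final-layer activations are ignored. With g = f'(dP/dQ) = 1 + log(dP/dQ), the bound equals the true divergence.

```python
import math


def f_divergence_bound(g_p, g_q, f_star):
    """Sample estimate of E_P[g] - E_Q[f_star(g)] from discriminator outputs."""
    mean_p = sum(g_p) / len(g_p)
    mean_q = sum(f_star(v) for v in g_q) / len(g_q)
    return mean_p - mean_q


def kl_f_star(y):
    # Legendre transform of f(t) = t log t
    return math.exp(y - 1.0)


# P = (2/3, 1/3) and Q = (1/2, 1/2) on {0, 1}, represented by exact samples.
xs_p, ys_q = [0, 0, 1], [0, 1]
ratio = {0: 4.0 / 3.0, 1: 2.0 / 3.0}                  # dP/dQ
g = {k: 1.0 + math.log(r) for k, r in ratio.items()}  # optimal g = f'(dP/dQ)
estimate = f_divergence_bound([g[x] for x in xs_p], [g[y] for y in ys_q], kl_f_star)
```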

class Divergences_tf.squared_Hellinger_LT(*args, **kwargs)[source]

Squared Hellinger distance class based on the Legendre transform. H(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (sqrt(y) - 1)^2.

final_layer_activation(y)[source]

Applies the final layer activation for squared Hellinger distance.
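For f(t) = (sqrt(t) - 1)^2, the Legendre transform is f^*(y) = y / (1 - y). It is finite only for y < 1, so discriminator outputs must stay below 1. The sketch uses the output activation 1 - exp(-v), which is one common choice for this divergence. That activation is an assumption; this library's final_layer_activation may differ. The sketch also checks the closed form with a grid search at y = 0.5, where the maximizer is t = 1 / (1 - y)^2 = 4. Names are hypothetical.

```python
import math


def hellinger_f(t):
    return (math.sqrt(t) - 1.0) ** 2


def hellinger_f_star(y):
    """Legendre transform of f(t) = (sqrt(t) - 1)^2, defined for y < 1."""
    if y >= 1.0:
        raise ValueError("f* is infinite for y >= 1")
    return y / (1.0 - y)


def hellinger_activation(v):
    """Maps real v into (-inf, 1), the domain of f* (saturates to 1.0 for very large v)."""
    return 1.0 - math.exp(-v)


# Grid check of the closed form at y = 0.5 over t in [0, 20].
grid = [k / 1000.0 for k in range(20_001)]
approx = max(0.5 * t - hellinger_f(t) for t in grid)
```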

class Divergences_torch.Discriminator_Penalty(penalty_weight)[source]

Base class for penalty terms applied to the divergence objective functional during training. Allows discriminator constraints to be enforced (approximately).

evaluate(discriminator, x, y, labels=None)[source]

Evaluates the penalty term. Should be overridden by subclasses.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None. (Subclasses should implement penalty evaluation.)

get_penalty_weight()[source]

Returns the weight of the penalty term.

set_penalty_weight(weight)[source]

Sets the weight of the penalty term.

class Divergences_torch.Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Base class for divergence measures D(P||Q) between random variables x~P, y~Q. Defines the parameters and functions common to all divergence measures.

discriminate(x, labels=None)[source]

Discriminates between samples from distributions P and Q.

Parameters:
  • x – Input data to discriminate.

  • labels – Optional labels for the input data.

Returns:

Discriminator output.

discriminator_loss(x, y, labels=None)[source]

Computes the discriminator loss.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Discriminator loss.

estimate(x, y, labels=None)[source]

Estimates the divergence measure.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Estimated divergence loss.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for the divergence measure. Should be implemented by subclasses.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None (to be overridden by subclasses).

generator_loss(x, labels=None)[source]

Computes the generator loss.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss.

get_batch_size()[source]

Returns the batch size.

get_discriminator()[source]

Returns the discriminator model.

get_learning_rate()[source]

Returns the learning rate.

get_no_epochs()[source]

Returns the number of training epochs.

set_batch_size(BATCH_SIZE)[source]

Sets the batch size.

set_discriminator(discriminator)[source]

Sets a new discriminator model.

set_learning_rate(lr)[source]

Sets the learning rate.

set_no_epochs(epochs)[source]

Sets the number of training epochs.

train(data_P, data_Q, labels=None, save_estimates=True)[source]

Trains the discriminator for the configured number of epochs (see get_no_epochs()).

Parameters:
  • data_P – Data samples from distribution P.

  • data_Q – Data samples from distribution Q.

  • labels – Optional labels for the input data.

  • save_estimates – Whether to save divergence estimates.

Returns:

A list of divergence estimates for each epoch.

train_step(x, y, labels=None)[source]

Performs a training step for the discriminator.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Loss value for the current step.

class Divergences_torch.Gradient_Penalty_1Sided(penalty_weight, Lip_const)[source]

One-sided gradient penalty class that (approximately) enforces a Lipschitz constant of at most Lip_const.

evaluate(discriminator, x, y, labels=None)[source]

Computes the one-sided gradient penalty, which penalizes discriminator gradient norms that exceed Lip_const.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

One-sided gradient penalty value.

get_Lip_constant()[source]

Returns the target Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.
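A common way to implement such a penalty, in the style of WGAN-GP, works as follows. Evaluate the discriminator's gradient norm at random interpolates t x + (1 - t) y, then penalize only norms above Lip_const. The sketch assumes the squared-hinge form penalty_weight * E[max(0, ||grad g|| - Lip_const)^2]. The exact form used by evaluate may differ. To run with the standard library alone, it uses central finite differences in place of autodiff. All names are hypothetical.

```python
import random


def grad_norm(g, point, eps=1e-5):
    """Euclidean norm of the gradient of g at point, by central differences."""
    total = 0.0
    for i in range(len(point)):
        up = list(point)
        down = list(point)
        up[i] += eps
        down[i] -= eps
        total += ((g(up) - g(down)) / (2.0 * eps)) ** 2
    return total ** 0.5


def one_sided_gradient_penalty(g, xs, ys, penalty_weight, lip_const, rng=random):
    """penalty_weight * mean of max(0, ||grad g|| - lip_const)^2 at interpolates."""
    terms = []
    for x, y in zip(xs, ys):
        t = rng.random()  # random interpolation weight in [0, 1)
        z = [t * a + (1.0 - t) * b for a, b in zip(x, y)]
        excess = max(0.0, grad_norm(g, z) - lip_const)
        terms.append(excess ** 2)
    return penalty_weight * sum(terms) / len(terms)


def steep_critic(v):
    # linear critic with gradient (3, 4), i.e. gradient norm 5 everywhere
    return 3.0 * v[0] + 4.0 * v[1]


xs = [[0.0, 0.0], [1.0, 2.0]]
ys = [[1.0, 1.0], [-1.0, 0.5]]
penalty = one_sided_gradient_penalty(steep_critic, xs, ys, penalty_weight=10.0, lip_const=1.0)
# (5 - 1)^2 * 10 = 160, up to finite-difference error
```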

class Divergences_torch.Gradient_Penalty_2Sided(penalty_weight, Lip_const)[source]

Two-sided gradient penalty class that (approximately) enforces a Lipschitz constant equal to Lip_const.

evaluate(discriminator, x, y, labels=None)[source]

Computes the two-sided gradient penalty, which penalizes deviations of the discriminator gradient norm from Lip_const in either direction.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Two-sided gradient penalty value.

get_Lip_constant()[source]

Returns the target Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.
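Unlike the one-sided variant, the two-sided penalty also penalizes gradient norms below Lip_const. The sketch assumes the WGAN-GP form penalty_weight * E[(||grad g|| - Lip_const)^2]; the library's evaluate may differ. The contrast between the two penalties is easiest to see on precomputed gradient norms. Both function names are hypothetical.

```python
def two_sided_penalty(grad_norms, penalty_weight, lip_const):
    """penalty_weight * mean((||grad g|| - lip_const)^2): pulls norms toward lip_const."""
    return penalty_weight * sum((n - lip_const) ** 2 for n in grad_norms) / len(grad_norms)


def one_sided_penalty(grad_norms, penalty_weight, lip_const):
    """penalty_weight * mean(max(0, ||grad g|| - lip_const)^2): only caps the norm."""
    return penalty_weight * sum(max(0.0, n - lip_const) ** 2 for n in grad_norms) / len(grad_norms)


norms = [0.5, 1.0, 2.0]  # hypothetical gradient norms at interpolated points
two = two_sided_penalty(norms, penalty_weight=1.0, lip_const=1.0)  # the 0.5 norm is penalized
one = one_sided_penalty(norms, penalty_weight=1.0, lip_const=1.0)  # only the 2.0 norm is penalized
```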

class Divergences_torch.IPM(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

IPM class for estimating integral probability metrics (IPMs) between P and Q, x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for IPM.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

IPM loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the variational formula for IPM applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on IPM.
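The IPM variational formula is IPM_Gamma(P, Q) = sup_{g in Gamma} {E_P[g] - E_Q[g]}, so any admissible g gives a lower bound. For example, g(x) = x is 1-Lipschitz. The difference of sample means therefore lower-bounds the Wasserstein-1 distance, which is the IPM over 1-Lipschitz functions. A minimal sketch with a hypothetical function name:

```python
def ipm_estimate(g_p, g_q):
    """Sample estimate of E_P[g] - E_Q[g] from discriminator outputs."""
    return sum(g_p) / len(g_p) - sum(g_q) / len(g_q)


xs_p = [1.0, 2.0, 3.0]
ys_q = [0.0, 1.0, 2.0]
lower_bound_w1 = ipm_estimate(xs_p, ys_q)  # g(x) = x, so the outputs are the samples themselves
```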

class Divergences_torch.Jensen_Shannon_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y) - (y + 1) * log((y + 1) / 2).

final_layer_activation(y)[source]

Applies the final layer activation for Jensen-Shannon divergence.
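For f(t) = t log t - (t + 1) log((t + 1) / 2), the Legendre transform is f^*(y) = -log(2 - exp(y)). It is finite only for y < log 2. The sketch pairs it with the output activation log 2 - log(1 + exp(-v)), a common choice that keeps outputs below log 2. That activation is an assumption; the library's final_layer_activation may differ. The sketch also checks the closed form with a grid search. Names are hypothetical.

```python
import math


def js_f(t):
    return t * math.log(t) - (t + 1.0) * math.log((t + 1.0) / 2.0)


def js_f_star(y):
    """Legendre transform of js_f, defined for y < log 2."""
    if y >= math.log(2.0):
        raise ValueError("f* is infinite for y >= log 2")
    return -math.log(2.0 - math.exp(y))


def js_activation(v):
    """Maps real v into (-inf, log 2), the domain of f*."""
    return math.log(2.0) - math.log1p(math.exp(-v))


# Grid check at y = 0.3; the maximizer is t = e^y / (2 - e^y), about 2.08.
y = 0.3
grid = [k / 10_000.0 for k in range(1, 100_001)]  # t in (0, 10]
approx = max(y * t - js_f(t) for t in grid)
```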

class Divergences_torch.KLD_DV(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for KL divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

KL divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on KL divergence.
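The Donsker-Varadhan formula is KL(P||Q) = sup_g {E_P[g] - log E_Q[exp(g)]}. The minimal sketch below evaluates that objective on precomputed discriminator outputs. `dv_bound` is a hypothetical name. On a toy discrete example, choosing g = log(dP/dQ) makes the bound tight. The sketch subtracts the maximum before exponentiating, a standard log-sum-exp stabilization. This documentation does not state whether eval_var_formula does the same.

```python
import math


def dv_bound(g_p, g_q):
    """E_P[g] - log E_Q[exp(g)], computed with a stable log-mean-exp."""
    m = max(g_q)
    log_mean_exp = m + math.log(sum(math.exp(v - m) for v in g_q) / len(g_q))
    return sum(g_p) / len(g_p) - log_mean_exp


xs_p, ys_q = [0, 0, 1], [0, 1]  # exact samples of P = (2/3, 1/3), Q = (1/2, 1/2)
log_ratio = {0: math.log(4 / 3), 1: math.log(2 / 3)}  # log dP/dQ
estimate = dv_bound([log_ratio[x] for x in xs_p], [log_ratio[y] for y in ys_q])
```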

class Divergences_torch.KLD_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y).
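For f(t) = t log t, the supremum in f^*(y) = sup_{t > 0} {y t - t log t} is attained at t = exp(y - 1), which gives f^*(y) = exp(y - 1). A quick grid check of that closed form (function names are hypothetical):

```python
import math


def kl_f_star(y):
    """Legendre transform of f(t) = t * log(t)."""
    return math.exp(y - 1.0)


def grid_sup(y, t_max=10.0, steps=100_000):
    """Brute-force sup_{0 < t <= t_max} {y t - t log t} on a uniform grid."""
    best = -math.inf
    for k in range(1, steps + 1):
        t = t_max * k / steps
        best = max(best, y * t - t * math.log(t))
    return best


for y in (-1.0, 0.0, 1.0):
    print(y, kl_f_star(y), grid_sup(y))
```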

class Divergences_torch.Pearson_chi_squared_HCR(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Pearson chi-squared divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Pearson chi-squared divergence.
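The Hammersley-Chapman-Robbins bound states chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g] for any g with Var_Q[g] > 0. Equality holds when g is an affine function of dP/dQ. The sketch evaluates this ratio on precomputed discriminator outputs, using the population (1/n) variance. `hcr_bound` is a hypothetical name, and eval_var_formula may use a different but equivalent objective.

```python
def hcr_bound(g_p, g_q):
    """(E_P[g] - E_Q[g])^2 / Var_Q[g] from discriminator outputs."""
    mean_p = sum(g_p) / len(g_p)
    mean_q = sum(g_q) / len(g_q)
    var_q = sum((v - mean_q) ** 2 for v in g_q) / len(g_q)
    if var_q == 0:
        raise ValueError("Var_Q[g] must be positive")
    return (mean_p - mean_q) ** 2 / var_q


xs_p, ys_q = [0, 0, 1], [0, 1]        # exact samples of P = (2/3, 1/3), Q = (1/2, 1/2)
ratio = {0: 4.0 / 3.0, 1: 2.0 / 3.0}  # dP/dQ
estimate = hcr_bound([ratio[x] for x in xs_p], [ratio[y] for y in ys_q])
# true chi^2(P||Q) = sum (p - q)^2 / q = 1/9
```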

class Divergences_torch.Pearson_chi_squared_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (y - 1)^2.
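For f(t) = (t - 1)^2, setting the derivative of y t - (t - 1)^2 to zero gives t = 1 + y/2. Over all real t this yields f^*(y) = y + y^2/4. If t is restricted to t >= 0, the formula holds for y >= -2; for y < -2 the supremum sits at t = 0 and equals -1. The sketch below implements the t >= 0 version and checks it on a grid. Which domain the library uses is not stated here, and the names are hypothetical.

```python
def pearson_f_star(y):
    """Legendre transform of f(t) = (t - 1)^2 over t >= 0."""
    if y < -2.0:
        return -1.0  # maximizer clipped to t = 0, value -f(0)
    return y + y * y / 4.0  # maximizer t = 1 + y / 2


def grid_sup(y, t_max=10.0, steps=100_000):
    """Brute-force sup_{0 <= t <= t_max} {y t - (t - 1)^2} on a uniform grid."""
    best = float("-inf")
    for k in range(steps + 1):
        t = t_max * k / steps
        best = max(best, y * t - (t - 1.0) ** 2)
    return best
```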

class Divergences_torch.Renyi_Divergence(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Renyi divergence class for estimating the Renyi divergence R_alpha(P||Q), x~P, y~Q.

get_order()[source]

Returns the order of the Renyi divergence.

set_order(alpha)[source]

Sets the order of the Renyi divergence.

class Divergences_torch.Renyi_Divergence_CC(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]

Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Renyi divergence based on the convex-conjugate formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Renyi divergence using the convex-conjugate formula.

final_layer_activation(y)[source]

Applies the final layer activation to enforce positive values.
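The sketch assumes the scaling R_alpha(P||Q) = log E_Q[(dP/dQ)^alpha] / (alpha (alpha - 1)). This is consistent with Renyi_Divergence_CC_rescaled computing alpha * R_alpha, which is then the standard Renyi divergence. Under that assumption, one form of the convex-conjugate formula for alpha > 1 is:

R_alpha(P||Q) = sup_{g > 0} {-E_Q[g] + log E_P[g^((alpha - 1)/alpha)] / (alpha - 1)} + (log alpha + 1) / alpha

The positivity constraint on g appears to match the positive-output final_layer_activation documented above. The sketch evaluates this objective on precomputed outputs and checks that the maximizing g recovers R_alpha on a toy example. Names are hypothetical, and the library may arrange the objective differently.

```python
import math


def renyi_cc_bound(g_p, g_q, alpha):
    """Convex-conjugate objective for R_alpha, alpha > 1, with positive g."""
    if min(g_p) <= 0 or min(g_q) <= 0:
        raise ValueError("g must be positive")
    e_q = sum(g_q) / len(g_q)
    e_p = sum(v ** ((alpha - 1) / alpha) for v in g_p) / len(g_p)
    return -e_q + math.log(e_p) / (alpha - 1) + (math.log(alpha) + 1) / alpha


def renyi_exact(p, q, alpha):
    """R_alpha = log sum_x p^alpha q^(1 - alpha) / (alpha (alpha - 1))."""
    s = sum(pi ** alpha * qi ** (1 - alpha) for pi, qi in zip(p, q))
    return math.log(s) / (alpha * (alpha - 1))


# Toy example: P = (2/3, 1/3), Q = (1/2, 1/2) on {0, 1}, with exact samples.
alpha = 2.0
p, q = [2 / 3, 1 / 3], [1 / 2, 1 / 2]
ratio = [pi / qi for pi, qi in zip(p, q)]                   # dP/dQ on {0, 1}
m = sum(qi * r ** alpha for qi, r in zip(q, ratio))         # E_Q[(dP/dQ)^alpha]
k = (m / alpha ** (alpha - 1)) ** (1 / alpha)
g_opt = [r ** alpha / (alpha * k) ** alpha for r in ratio]  # maximizing g
xs_p, ys_q = [0, 0, 1], [0, 1]
bound = renyi_cc_bound([g_opt[x] for x in xs_p], [g_opt[y] for y in ys_q], alpha)
```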

class Divergences_torch.Renyi_Divergence_CC_rescaled(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]

Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for the rescaled Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Rescaled Renyi divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on rescaled Renyi divergence.

final_layer_activation(y)[source]

Applies the final layer activation and scales it by alpha.
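The sketch assumes R_alpha is scaled by 1 / (alpha (alpha - 1)). Under that assumption, alpha * R_alpha(P||Q) is the standard Renyi divergence D_alpha(P||Q) = log E_Q[(dP/dQ)^alpha] / (alpha - 1). D_alpha is nondecreasing in alpha and tends to the worst-case regret divergence D_infty(P||Q) = log ess sup dP/dQ. The function name `rescaled_renyi` is hypothetical. The sketch computes the rescaled value for discrete distributions:

```python
import math


def rescaled_renyi(p, q, alpha):
    """alpha * R_alpha(P||Q) = log sum_x p^alpha q^(1-alpha) / (alpha - 1), alpha > 1."""
    s = sum(pi ** alpha * qi ** (1 - alpha) for pi, qi in zip(p, q))
    return math.log(s) / (alpha - 1)


p, q = [2 / 3, 1 / 3], [1 / 2, 1 / 2]
values = [rescaled_renyi(p, q, a) for a in (2.0, 5.0, 20.0, 100.0)]
d_infty = math.log(max(pi / qi for pi, qi in zip(p, q)))  # log(4/3)
```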

class Divergences_torch.Renyi_Divergence_DV(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Renyi divergence based on the Renyi-Donsker-Varadhan formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Renyi divergence.
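The sketch assumes the scaling R_alpha(P||Q) = log E_Q[(dP/dQ)^alpha] / (alpha (alpha - 1)), so that alpha * R_alpha is the standard Renyi divergence. Under that assumption, the Renyi-Donsker-Varadhan formula reads:

R_alpha(P||Q) = sup_g {log E_P[exp((alpha - 1) g)] / (alpha - 1) - log E_Q[exp(alpha g)] / alpha}

Choosing g = log(dP/dQ) makes it tight. The sketch works on precomputed discriminator outputs. `renyi_dv_bound` is a hypothetical name, alpha > 1 is assumed, and there is no numerical stabilization.

```python
import math


def renyi_dv_bound(g_p, g_q, alpha):
    """log E_P[e^{(alpha-1) g}] / (alpha - 1) - log E_Q[e^{alpha g}] / alpha."""
    e_p = sum(math.exp((alpha - 1) * v) for v in g_p) / len(g_p)
    e_q = sum(math.exp(alpha * v) for v in g_q) / len(g_q)
    return math.log(e_p) / (alpha - 1) - math.log(e_q) / alpha


xs_p, ys_q = [0, 0, 1], [0, 1]  # exact samples of P = (2/3, 1/3), Q = (1/2, 1/2)
log_ratio = {0: math.log(4 / 3), 1: math.log(2 / 3)}  # log dP/dQ
bound = renyi_dv_bound([log_ratio[x] for x in xs_p], [log_ratio[y] for y in ys_q], alpha=2.0)
# equals R_2(P||Q) = log(10/9) / 2 under the assumed scaling
```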

class Divergences_torch.Renyi_Divergence_WCR(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]

Renyi divergence class as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for worst-case regret divergence as alpha approaches infinity.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Worst-case regret divergence loss.

eval_var_formula_gen(X, labels=None)[source]

Evaluates the generator’s objective based on worst-case regret divergence.

class Divergences_torch.alpha_Divergence_LT(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f_alpha based on the alpha value.

get_order()[source]

Returns the order of the alpha-divergence.

set_order(alpha)[source]

Sets the order of the alpha-divergence.

class Divergences_torch.f_Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses must implement the Legendre transform f_star.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for f-divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

f-divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the variational formula for f-divergence applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on f-divergence.

f_star(y)[source]

Placeholder for the Legendre transform of f. Should be implemented by subclasses.

final_layer_activation(y)[source]

Applies the final layer activation.

class Divergences_torch.squared_Hellinger_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Squared Hellinger distance class based on the Legendre transform. H(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (sqrt(y) - 1)^2.

final_layer_activation(y)[source]

Applies the final layer activation for squared Hellinger distance.

class GAN_CIFAR10_jax.Discriminator(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Discriminator class for an unconditional GAN model. Applies several convolutional layers followed by LeakyReLU activations to classify the input.

__call__()[source]

The forward pass through the network.

class GAN_CIFAR10_jax.Discriminator_cond(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Conditional Discriminator class for a GAN model. Takes both images and labels as input and discriminates between real and fake images conditioned on the labels.

__call__()[source]

The forward pass through the network.

class GAN_CIFAR10_jax.Generator(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Generator class for an unconditional GAN model. Takes a latent code and generates images using deconvolutional layers.

__call__()[source]

The forward pass through the network.

class GAN_CIFAR10_jax.Generator_cond(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Conditional Generator class for a GAN model. Takes latent codes and labels to generate conditional images.

__call__()[source]

The forward pass through the network.

class GAN_CIFAR10_tf.Discriminator(*args, **kwargs)[source]

Discriminator model, which evaluates whether inputs are real or fake.

build(input_shape)[source]

Builds the discriminator.

Parameters:

input_shape – The shape of the input tensor.

call(inputs, training=True)[source]

Forward pass of the discriminator: passes input images through the network.

Parameters:
  • inputs (tf.Tensor) – Input images.

  • training (bool) – Whether the model is in training mode.

Returns:

Discriminator score for each image.

Return type:

tf.Tensor

class GAN_CIFAR10_tf.DiscriminatorBlock(*args, **kwargs)[source]

ResNet-style block for the discriminator model, with optional downsampling.

Parameters:
  • in_chans (int) – Number of input channels.

  • out_chans (int) – Number of output channels.

  • downsample (bool) – Whether to apply 2x downsampling.

  • first (bool) – Whether this is the first block in the discriminator.

build(input_shape)[source]

Builds the layer.

Parameters:

input_shape – The shape of the input tensor.

call(inputs, training=False)[source]

Forward pass of the discriminator block.

Parameters:

inputs (tf.Tensor) – Input tensor.

Returns:

Output tensor after passing through the block.

Return type:

tf.Tensor

class GAN_CIFAR10_tf.Generator(*args, **kwargs)[source]

Generator model that consists of a dense layer followed by multiple generator blocks.

build(input_shape)[source]

Builds the generator.

Parameters:

input_shape – The shape of the input tensor.

call(noise, training=True)[source]

Forward pass of the generator: passes input latent vectors through the network.

Parameters:
  • noise (tf.Tensor) – Latent vectors.

  • training (bool) – Whether the model is in training mode.

Returns:

Generated image tensor.

Return type:

tf.Tensor

class GAN_CIFAR10_tf.GeneratorBlock(*args, **kwargs)[source]

ResNet-style block for the generator model, with optional upsampling.

Parameters:
  • in_chans (int) – Number of input channels.

  • out_chans (int) – Number of output channels.

  • upsample (bool) – Whether to apply 2x upsampling.

build(input_shape)[source]

Builds the layer.

Parameters:

input_shape – The shape of the input tensor.

call(inputs)[source]

Forward pass of the generator block.

Parameters:

inputs (tf.Tensor) – Input tensor.

Returns:

Output tensor after passing through the block.

Return type:

tf.Tensor

compute_output_shape(input_shape)[source]

Computes the output shape of the block.

Parameters:

input_shape – The shape of the input tensor.

Returns:

Tuple of the output shape.

classmethod from_config(config)[source]

Instantiates the block from a configuration.

Parameters:

config – Configuration dictionary.

Returns:

An instance of GeneratorBlock.

Return type:

GeneratorBlock

get_config()[source]

Gets the configuration of the block for serialization.

Returns:

Configuration dictionary.

Return type:

dict

GAN_CIFAR10_tf.avg_pool2d(x)[source]

Implements a twice-differentiable 2x2 average pooling operation.

Parameters:

x (tf.Tensor) – Input tensor of shape (batch_size, height, width, channels).

Returns:

Average-pooled tensor.

Return type:

tf.Tensor
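Because 2x2 average pooling is linear in its input, derivatives of every order exist. That matters when a gradient penalty requires differentiating through the discriminator twice. The pure-Python sketch below works on a single-channel image stored as nested lists. It assumes even height and width, and `avg_pool_2x2` is a hypothetical name. The library functions operate on batched tensors in the layouts stated in their parameter descriptions.

```python
def avg_pool_2x2(image):
    """Average each non-overlapping 2x2 window of an H x W grid (H, W even)."""
    h, w = len(image), len(image[0])
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    return [
        [
            (image[i][j] + image[i][j + 1] + image[i + 1][j] + image[i + 1][j + 1]) / 4.0
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


pooled = avg_pool_2x2([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12],
                       [13, 14, 15, 16]])
```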

class GAN_CIFAR10_torch.Discriminator[source]

Discriminator (or critic) model for GAN.

forward(*inputs, labels=None)[source]

Forward pass of the discriminator: passes input images through the network.

Parameters:

inputs (torch.Tensor) – Input image tensor.

Returns:

Scalar score for each image.

Return type:

torch.Tensor

class GAN_CIFAR10_torch.DiscriminatorBlock(in_chans, out_chans, downsample=False, first=False)[source]

ResNet-style block for the discriminator model, with optional downsampling.

Parameters:
  • in_chans (int) – Number of input channels.

  • out_chans (int) – Number of output channels.

  • downsample (bool) – Whether to apply 2x downsampling.

  • first (bool) – Whether this is the first block in the discriminator.

forward(*inputs)[source]

Forward pass of the discriminator block.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

Output tensor after passing through the block.

Return type:

torch.Tensor

class GAN_CIFAR10_torch.Discriminator_cond[source]

Conditional Discriminator (or critic) model for conditional GAN.

forward(x, labels)[source]

Forward pass of the conditional discriminator: passes input images and labels through the network.

Parameters:
  • x (torch.Tensor) – Input image tensor.

  • labels (torch.Tensor) – One-hot encoded labels.

Returns:

Scalar score for each image.

Return type:

torch.Tensor

class GAN_CIFAR10_torch.Generator[source]

Generator model that consists of linear layers followed by multiple generator blocks.

forward(*inputs, labels=None)[source]

Forward pass of the generator: passes input latent vectors through the network.

Parameters:
  • inputs (torch.Tensor) – Latent vectors.

  • labels (torch.Tensor) – (Optional) Labels, if required.

Returns:

Generated image tensor.

Return type:

torch.Tensor

class GAN_CIFAR10_torch.GeneratorBlock(in_chans, out_chans, upsample=False)[source]

ResNet-style block for the generator model, with optional upsampling.

Parameters:
  • in_chans (int) – Number of input channels.

  • out_chans (int) – Number of output channels.

  • upsample (bool) – Whether to apply 2x upsampling.

forward(*inputs)[source]

Forward pass of the generator block.

Parameters:

inputs (torch.Tensor) – Input tensor.

Returns:

Output tensor after passing through the block.

Return type:

torch.Tensor

class GAN_CIFAR10_torch.Generator_cond[source]

Conditional Generator model for conditional GAN.

forward(z, labels)[source]

Forward pass of the conditional generator: passes latent vectors and labels through the network.

Parameters:
  • z (torch.Tensor) – Latent vectors.

  • labels (torch.Tensor) – One-hot encoded labels.

Returns:

Generated image tensor.

Return type:

torch.Tensor

GAN_CIFAR10_torch.avg_pool2d(x)[source]

Implements a twice-differentiable 2x2 average pooling operation.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, channels, height, width).

Returns:

Average-pooled tensor.

Return type:

torch.Tensor

class GAN_MNIST_jax.Discriminator_MNIST_cond(parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)[source]

class GAN_MNIST_jax.Generator_MNIST_cond(latent_dim: int = 118, num_classes: int = 10, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)[source]

GAN_MNIST_tf.Discriminator_MNIST_cond()[source]

Conditional Discriminator model for MNIST dataset.

Returns:

A Keras model that takes as input an image and a one-hot encoded label and outputs a scalar value indicating real or fake.

Return type:

Model

Input:

  • x_input (Tensor): Input MNIST image of shape (batch_size, 28, 28).

  • z_input (Tensor): One-hot encoded label of shape (batch_size, 10).

Output:

Tensor: Discriminator output scalar for each image in the batch.

GAN_MNIST_tf.Generator_MNIST_cond(latent_dim=118)[source]

Conditional Generator model for MNIST dataset.

Parameters:

latent_dim (int) – Dimension of the latent input vector. Default is 118.

Returns:

A Keras model that takes as input a latent vector and a one-hot encoded label and outputs a generated image.

Return type:

Model

Input:

  • label (Tensor): One-hot encoded label of shape (batch_size, 10).

  • z (Tensor): Latent input vector of shape (batch_size, latent_dim).

Output:

Tensor: Generated image of shape (batch_size, 28, 28, 1).
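Both conditional MNIST models expect labels as one-hot vectors of shape (batch_size, 10). TensorFlow's tf.one_hot(labels, depth=10) produces this encoding. PyTorch's torch.nn.functional.one_hot(labels, num_classes=10) does the same, but returns integers, so cast to float before feeding a network. A framework-free sketch of the encoding (`one_hot` here is a hypothetical helper):

```python
def one_hot(labels, num_classes=10):
    """Encode integer class labels as one-hot rows of length num_classes."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


batch = one_hot([3, 0, 9])  # 3 rows of length 10
```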

class GAN_MNIST_torch.Discriminator_MNIST[source]

Unconditional Discriminator model for MNIST dataset.

forward(x)[source]

Classifies the input image as real or fake.

Input:

x (Tensor): Input MNIST image of shape (BATCH, 1, 28, 28).

Output:

Tensor: Discriminator output scalar for each image in the batch.

class GAN_MNIST_torch.Discriminator_MNIST_cond[source]

Conditional Discriminator model for MNIST dataset.

forward(x, z)[source]

Classifies the input image and label as real or fake.

Input:

  • x (Tensor): Input MNIST image of shape (BATCH, 1, 28, 28).

  • z (Tensor): One-hot encoded label tensor of shape (BATCH, 10).

Output:

Tensor: Discriminator output scalar for each image in the batch.

class GAN_MNIST_torch.Generator_MNIST(latent_dim=118)[source]

Unconditional Generator model for MNIST dataset.

Parameters:

latent_dim (int) – Dimension of the latent input vector. Default is 118.

forward(BATCH=16)[source]

Generates an image based on a random noise vector.

Input:

BATCH (int): Batch size for random latent vector generation. Default is 16.

Output:

Tensor: Generated image of shape (BATCH, 1, 28, 28).

class GAN_MNIST_torch.Generator_MNIST_cond(latent_dim=118)[source]

Conditional Generator model for MNIST dataset.

Parameters:

latent_dim (int) – Dimension of the latent input vector. Default is 118.

forward(label, BATCH=16)[source]

Generates an image conditioned on the input label and a random noise vector.

Input:

  • label (Tensor): One-hot encoded label tensor of shape (BATCH, 10).

  • BATCH (int): Batch size for random latent vector generation. Default is 16.

Output:

Tensor: Generated image of shape (BATCH, 1, 28, 28).

class GAN_jax.GAN(divergence, generator, gen_optimizer, noise_source, epochs, disc_steps_per_gen_step, batch_size=None, reverse_order=False, include_penalty_in_gen_loss=False, cnn=False)[source]

Class for training a GAN using one of the provided divergences. If reverse_order=False, the GAN solves min_theta D(P||g_theta(Z)); if reverse_order=True, it solves min_theta D(g_theta(Z)||P). Here P is the distribution to be learned, Z is the noise source, and g_theta is the generator with parameters theta.

disc_train_step(x, z, disc_params, disc_opt_state)[source]

Performs one update of the discriminator’s parameters.

estimate_loss(x, z, gen_params)[source]

Estimates the loss.

gen_train_step(x, z, gen_params, gen_opt_state)[source]

Performs one update of the generator’s parameters.

train(data_P, disc_params, gen_params, disc_opt_state, gen_opt_state, save_frequency=None, num_gen_samples_to_save=None, save_loss_estimates=False)[source]

Trains the GAN.
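As described above, reverse_order decides whether the real data is the first or second argument of D(.||.). The parameter disc_steps_per_gen_step sets how many discriminator updates accompany each generator update. The framework-free sketch below illustrates those two conventions. Both function names are hypothetical, and the sketch assumes the discriminator steps come first. The actual train loop also handles batching, noise sampling, and saving.

```python
def divergence_arguments(real, generated, reverse_order=False):
    """Return (first, second) arguments of D(first || second)."""
    # reverse_order=False: D(P || g_theta(Z)); reverse_order=True: D(g_theta(Z) || P)
    return (generated, real) if reverse_order else (real, generated)


def update_schedule(num_gen_steps, disc_steps_per_gen_step):
    """Order of parameter updates: k discriminator steps, then one generator step."""
    schedule = []
    for _ in range(num_gen_steps):
        schedule.extend(["disc"] * disc_steps_per_gen_step)
        schedule.append("gen")
    return schedule


first, second = divergence_arguments("data", "fake", reverse_order=True)  # ("fake", "data")
steps = update_schedule(2, 3)  # three "disc" steps, then one "gen" step, repeated twice
```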

class GAN_tf.GAN(divergence, generator, gen_optimizer, noise_source, epochs, disc_steps_per_gen_step, batch_size=None, reverse_order=False, include_penalty_in_gen_loss=False)[source]

Class for training a GAN using one of the provided divergences. If reverse_order=False, the GAN solves min_theta D(P||g_theta(Z)); if reverse_order=True, it solves min_theta D(g_theta(Z)||P). Here P is the distribution to be learned, Z is the noise source, and g_theta is the generator with parameters theta.

disc_train_step(x, z)[source]

Performs one update of the discriminator’s parameters.

estimate_loss(x, z)[source]

Estimates the loss.

gen_train_step(x, z)[source]

Performs one update of the generator’s parameters.

train(data_P, save_frequency=None, num_gen_samples_to_save=None, save_loss_estimates=False)[source]

Trains the GAN.

class GAN_torch.GAN(divergence, generator, gen_optimizer, noise_source, epochs, disc_steps_per_gen_step, batch_size=None, reverse_order=False, include_penalty_in_gen_loss=False)[source]

Class for training a GAN using one of the provided divergences. If reverse_order=False, the GAN solves min_theta D(P||g_theta(Z)); if reverse_order=True, it solves min_theta D(g_theta(Z)||P). Here P is the distribution to be learned, Z is the noise source, and g_theta is the generator with parameters theta.

disc_train_step(x, z)[source]

Performs one update of the discriminator’s parameters.

estimate_loss(x, z)[source]

Estimates the loss.

gen_train_step(x, z)[source]

Performs one update of the generator’s parameters.

train(data_P, save_frequency=None, num_gen_samples_to_save=None, save_loss_estimates=False)[source]

Trains the GAN.

class model_jax.Discriminator(input_dim: int, spec_norm: bool, bounded: bool, layers_list: list, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Discriminator class responsible for initializing and processing the discriminator network.

Parameters:
  • input_dim (int) – Input dimension size.

  • spec_norm (bool) – Whether to apply spectral normalization to layers.

  • bounded (bool) – Whether to apply a bounded activation function to the final output.

  • layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.

Returns:

Output after applying the discriminator network with an optional bounded activation.

Return type:

jax.numpy.DeviceArray

bounded_activation()[source]

Apply bounded activation using the tanh function to constrain outputs within [-M, M].

Parameters:

x (jax.numpy.DeviceArray) – Input data.

Returns:

Bounded output after applying tanh activation.

Return type:

jax.numpy.DeviceArray
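The layers_list argument lists hidden-layer widths only. The sketch assumes a fully connected network with a single scalar output; check the source for the exact layer stack. Under that assumption, the dense-layer shapes implied by input_dim and layers_list are shown below. `mlp_weight_shapes` is a hypothetical helper. For model_jax.Generator, documented next, the same helper would apply with Z_dim as the input and output_dim=X_dim.

```python
def mlp_weight_shapes(input_dim, layers_list, output_dim=1):
    """(fan_in, fan_out) of each dense layer: input -> hidden layers -> output."""
    widths = [input_dim] + list(layers_list) + [output_dim]
    return list(zip(widths[:-1], widths[1:]))


shapes = mlp_weight_shapes(2, [32, 32])  # [(2, 32), (32, 32), (32, 1)]
```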

class model_jax.DiscriminatorMNIST(parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Discriminator for the MNIST dataset, responsible for classifying images as real or fake.

Returns:

Output after applying the discriminator network.

Return type:

jax.numpy.DeviceArray

class model_jax.Generator(X_dim: int, Z_dim: int, spec_norm: bool, layers_list: list, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Generator class responsible for initializing and processing the generator network.

Parameters:
  • X_dim (int) – Dimension of the output generated by the generator.

  • Z_dim (int) – Dimension of the input latent space.

  • spec_norm (bool) – Whether to apply spectral normalization to layers.

  • layers_list (list) – List of integers where each integer specifies the number of units in each hidden layer.

Returns:

Generated output after passing through the generator network.

Return type:

jax.numpy.DeviceArray

class model_tf.Discriminator(*args, **kwargs)[source]

Discriminator class responsible for initializing and processing the discriminator network.

Parameters:
  • input_dim (int) – Dimension of the input features.

  • spec_norm (bool) – Whether to apply spectral normalization to the layers.

  • bounded (bool) – Whether to apply bounded activation on the final output.

  • layers_list (list) – List of integers specifying the number of units for each hidden layer.

call(inputs)[source]

Forward pass through the network.

bounded_activation(x)[source]

Activation function to apply bounded output using tanh.

Returns:

Output after processing through the discriminator network.

Return type:

tf.Tensor

bounded_activation(x)[source]

Applies a bounded activation using the tanh function, constraining the output to [-M, M].

Parameters:

x (tf.Tensor) – Input data.

Returns:

Output after applying bounded activation.

Return type:

tf.Tensor

call(inputs)[source]

Forward pass through the discriminator.

Parameters:

inputs (tf.Tensor) – Input data to the network.

Returns:

Discriminator’s prediction after processing the inputs.

Return type:

tf.Tensor
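When spec_norm is True, each layer's weight matrix is spectrally normalized, that is, divided by its largest singular value so that the resulting linear map has spectral norm 1. The NumPy sketch below estimates that singular value by power iteration, which is the standard spectral-normalization recipe; which implementation the library actually relies on is not stated here. PyTorch, for example, provides torch.nn.utils.spectral_norm, which by default performs one power-iteration step per forward pass and reuses its estimate across steps; this sketch instead runs many iterations from a fresh start for clarity. spectral_normalize is a hypothetical helper name.

```python
import numpy as np


def spectral_normalize(w, n_iters=50, seed=0):
    """Return (w / sigma, sigma), where sigma estimates the largest singular value of w."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(w.shape[0])
    for _ in range(n_iters):
        # Alternate between right and left singular-vector estimates.
        v = w.T @ u
        v /= np.linalg.norm(v)
        u = w @ v
        u /= np.linalg.norm(u)
    sigma = u @ w @ v  # equals ||w v||, which converges to the top singular value
    return w / sigma, sigma


w = np.random.default_rng(2).standard_normal((64, 32))
w_sn, sigma = spectral_normalize(w)
# Compare the power-iteration estimate with the exact value from an SVD.
print(sigma, np.linalg.svd(w, compute_uv=False)[0])
```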

class model_tf.DiscriminatorMNIST(*args, **kwargs)[source]

Discriminator class for the MNIST dataset that classifies images as real or fake.

Returns:

Output after processing the input image through the discriminator network.

Return type:

tf.Tensor

call(x)[source]

Forward pass through the MNIST discriminator.

Parameters:

x (tf.Tensor) – Input image data reshaped to (batch_size, 784).

Returns:

Output after processing through dense and normalization layers.

Return type:

tf.Tensor
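The documented shape (batch_size, 784) corresponds to 28-by-28 MNIST images flattened row by row. Assuming a batch arrives as an array of shape (batch_size, 28, 28), the NumPy sketch below performs that flattening. Whether the reshape happens inside call or must be done by the caller is not stated here, and the random images are placeholder data.

```python
import numpy as np

# Placeholder batch of 8 grayscale images, 28 x 28 pixels each.
images = np.random.default_rng(0).random((8, 28, 28)).astype(np.float32)

# Flatten each image into a 784-dimensional row vector (28 * 28 = 784),
# giving the (batch_size, 784) layout described for x.
x = images.reshape(images.shape[0], -1)
print(x.shape)  # (8, 784)
```

The same call also handles a leading channel axis, since an array of shape (batch_size, 1, 28, 28) flattens to the same (batch_size, 784) layout.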

class model_tf.Generator(*args, **kwargs)[source]

Generator class that initializes the generator network and runs its forward pass.

Parameters:
  • X_dim (int) – Dimension of the output generated by the generator.

  • Z_dim (int) – Dimension of the input latent space.

  • spec_norm (bool) – Whether to apply spectral normalization to the layers.

  • layers_list (list) – List of integers specifying the number of units for each hidden layer.

Returns:

Generated output after processing the latent input.

Return type:

tf.Tensor

call(inputs)[source]

Forward pass through the generator.

Parameters:

inputs (tf.Tensor) – Input latent vector.

Returns:

Generated output.

Return type:

tf.Tensor

class model_torch.BoundedActivation(M=100.0)[source]

Bounded activation layer that applies a tanh-based bounded activation function.

bounded_activation(x)

Apply the bounded activation to the input tensor.

forward(input)[source]

Forward pass through the bounded activation layer.

Parameters:

input (torch.Tensor) – Input tensor.

Returns:

Output after applying bounded activation.

Return type:

torch.Tensor
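The documentation states only that the bound is applied with tanh and, for the TensorFlow version, that outputs lie in [-M, M]. One common form is M * tanh(x / M), which is close to the identity for |x| much smaller than M and saturates at plus or minus M. The NumPy sketch below assumes that form; the library could equally use another variant, such as M * tanh(x). Here bounded_activation is a standalone illustrative function, not the layer's method, and its default M=100.0 mirrors the model_torch.BoundedActivation signature.

```python
import numpy as np


def bounded_activation(x, M=100.0):
    """Squash x into [-M, M] using the assumed form M * tanh(x / M).

    Dividing by M inside tanh is an assumption: it keeps the map close to
    the identity for inputs much smaller than M in magnitude.
    """
    return M * np.tanh(x / M)


x = np.array([-1e4, -50.0, 0.0, 1.0, 1e4])
y = bounded_activation(x)
print(y)  # the extreme inputs map to -100.0 and 100.0
```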

class model_torch.Discriminator(input_dim, batch_size, spec_norm, bounded, layers_list, device='cpu')[source]

Discriminator class that initializes the discriminator network and runs its forward pass.

Parameters:
  • input_dim (int) – Dimension of the input features.

  • batch_size (int) – Batch size for processing.

  • spec_norm (bool) – Whether to apply spectral normalization to the layers.

  • bounded (bool) – Whether to apply bounded activation on the final output.

  • layers_list (list) – List of integers specifying the number of units for each hidden layer.

  • device (str) – Device to run the model on (‘cpu’ or ‘cuda’).

Returns:

Output after processing through the discriminator network.

Return type:

torch.Tensor

forward(inputs)[source]

Forward pass through the discriminator.

Parameters:

inputs (torch.Tensor) – Input data to the network.

Returns:

Discriminator’s prediction after processing the inputs.

Return type:

torch.Tensor

class model_torch.Discriminator_MNIST[source]

Discriminator class for MNIST that classifies input images as real or fake.

Returns:

Output after processing the input image through the discriminator network.

Return type:

torch.Tensor

forward(x)[source]

Forward pass through the MNIST discriminator.

Parameters:

x (torch.Tensor) – Input image data reshaped to (batch_size, 784).

Returns:

Output after processing through dense and normalization layers.

Return type:

torch.Tensor

class model_torch.Generator(X_dim, Z_dim, batch_size, spec_norm, layers_list, device='cpu')[source]

Generator class that initializes the generator network and runs its forward pass.

Parameters:
  • X_dim (int) – Dimension of the output generated by the generator.

  • Z_dim (int) – Dimension of the input latent space.

  • batch_size (int) – Batch size for processing.

  • spec_norm (bool) – Whether to apply spectral normalization to the layers.

  • layers_list (list) – List of integers specifying the number of units for each hidden layer.

  • device (str) – Device to run the model on (‘cpu’ or ‘cuda’).

Returns:

Generated output after processing the latent input.

Return type:

torch.Tensor

forward(inputs)[source]

Forward pass through the generator.

Parameters:

inputs (torch.Tensor) – Input latent vector.

Returns:

Generated output.

Return type:

torch.Tensor

Indices and tables