Divergences_jax module

class Divergences_jax.DataLoader(data, batch_size, shuffle=True)[source]

Bases: object

DataLoader class for loading and batching data during training.
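As a minimal sketch of the batching behaviour described above, the following stand-in iterates over a dataset in mini-batches with optional shuffling. The class name `SimpleLoader` and the `seed` argument are hypothetical; this is not the library's implementation, only an illustration of the `(data, batch_size, shuffle=True)` interface.

```python
import numpy as np


class SimpleLoader:
    """Hypothetical stand-in mirroring DataLoader(data, batch_size, shuffle=True)."""

    def __init__(self, data, batch_size, shuffle=True, seed=0):
        self.data = np.asarray(data)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)  # seed is an added assumption

    def __iter__(self):
        n = self.data.shape[0]
        # Draw a fresh permutation on every pass when shuffling is enabled.
        order = self.rng.permutation(n) if self.shuffle else np.arange(n)
        for start in range(0, n, self.batch_size):
            yield self.data[order[start:start + self.batch_size]]


batches = list(SimpleLoader(np.arange(10).reshape(10, 1), batch_size=4, shuffle=False))
shuffled = np.concatenate(list(SimpleLoader(np.arange(10), batch_size=3, shuffle=True)))
```

The last batch is smaller than `batch_size` when the dataset size is not a multiple of it; whether the actual class drops or keeps that remainder is not stated in this reference.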

class Divergences_jax.Discriminator_Penalty(penalty_weight)[source]

Bases: object

Base class for implementing penalties on the discriminator during training. Enables the implementation of discriminator constraints to regularize the divergence objective.

evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)[source]

Evaluates the penalty term. Should be overridden by subclasses.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • batch_stats – Batch statistics of the discriminator (additional model variables).

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

None. (Subclasses should implement specific penalty evaluations.)

get_penalty_weight()[source]

Returns the weight of the penalty.

Returns:

Penalty weight.

set_penalty_weight(weight)[source]

Sets the weight of the penalty.

Parameters:

weight – New penalty weight.

class Divergences_jax.Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: object

Base class for divergence measures D(P||Q) between random variables x~P, y~Q. This parent class defines the parameters and functions shared by the different divergence measures.

discriminate(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the discriminator model on the input samples x, which may be drawn from either P or Q.

Parameters:
  • x – Input data to be discriminated.

  • params – Parameters of the discriminator model.

  • vars – Additional variables such as batch statistics.

  • labels – Optional labels for the input data.

  • dropout_rng – Optional dropout key for stochasticity in dropout layers.

Returns:

Tuple of the discriminator output and optional updated batch statistics.

discriminator_loss(x, y, params, vars, labels=None, dropout_rng=None)[source]

Computes the loss for the discriminator model.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of discriminator loss and updated variables.

estimate(x, y, params, vars, labels=None, dropout_rng=None)[source]

Estimates the divergence between P and Q.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of estimated divergence and updated variables.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Placeholder method for evaluating the variational formula of a specific divergence. Should be overridden by subclasses.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

None.

gen_train_step(gen_state, disc_state, disc_vars, gen_vars, key, z, labels=None, dropout_rng=None)[source]

Performs a single training step for the generator.

Parameters:
  • gen_state – Generator optimizer state.

  • disc_state – Discriminator optimizer state.

  • disc_vars – Discriminator variables.

  • gen_vars – Generator variables.

  • key – Random key for JAX RNG.

  • z – Latent input to the generator.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Updated generator state and generator loss.

generator_loss(x, params, vars, labels=None, dropout_rng=None)[source]

Computes the loss for the generator model.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

get_batch_size()[source]

Returns the batch size.

get_discriminator()[source]

Returns the discriminator model.

get_learning_rate()[source]

Returns the learning rate.

get_no_epochs()[source]

Returns the number of training epochs.

set_batch_size(batch_size)[source]

Sets the batch size.

Parameters:

batch_size – New batch size.

set_discriminator(discriminator)[source]

Sets a new discriminator model.

Parameters:

discriminator – New discriminator model.

set_learning_rate(lr)[source]

Sets the learning rate.

Parameters:

lr – New learning rate.

set_no_epochs(epochs)[source]

Sets the number of training epochs.

Parameters:

epochs – New number of epochs.

train(data_P, data_Q, state, vars, save_estimates=True, labels=None, dropout_rng=None)[source]

Trains the model for a given number of epochs.

Parameters:
  • data_P – Data samples from distribution P.

  • data_Q – Data samples from distribution Q.

  • state – Discriminator optimizer state.

  • vars – Discriminator variables.

  • save_estimates – Whether to save divergence estimates.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of estimated divergences and losses for each epoch.

train_step(x, y, state, vars, key, labels=None, dropout_rng=None)[source]

Performs a single training step for the discriminator.

Parameters:
  • x – Samples from P.

  • y – Samples from Q.

  • state – Optimizer state.

  • vars – Additional discriminator variables.

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Updated state and loss value for the current step.
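The following is a minimal sketch of what one discriminator update could look like, not the library's code. It assumes a hypothetical linear discriminator, uses the negative Donsker-Varadhan objective as the loss, and replaces the optimizer held in `state` with plain gradient descent; the names `disc`, `loss_fn`, and `sgd_step` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def disc(params, x):
    # Hypothetical linear discriminator g(x) = <w, x> + b.
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    # Negative Donsker-Varadhan objective: minimizing this loss
    # maximizes E_P[g] - log E_Q[exp(g)].
    g_x, g_y = disc(params, x), disc(params, y)
    log_mean_exp = logsumexp(g_y) - jnp.log(g_y.shape[0])
    return -(jnp.mean(g_x) - log_mean_exp)


@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # One gradient-descent update; returns the new parameters and the loss
    # evaluated before the update.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (512, 2)) + 1.0  # samples from P
y = jax.random.normal(key_y, (512, 2))        # samples from Q
params = {"w": jnp.zeros(2), "b": jnp.zeros(())}

losses = []
for _ in range(50):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

With zero-initialized parameters the first loss is 0, and it decreases as the discriminator learns to separate the two sample sets.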

class Divergences_jax.Gradient_Penalty_1Sided(penalty_weight, Lip_const)[source]

Bases: Discriminator_Penalty

One-sided gradient penalty to enforce a constraint: Lipschitz constant <= Lip_const.

evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)[source]

Computes the one-sided gradient penalty to enforce the Lipschitz constant constraint.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • batch_stats – Batch statistics of the discriminator (additional model variables).

  • key – Random key for JAX RNG.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

One-sided gradient penalty value.

get_Lip_constant()[source]

Returns the target Lipschitz constant.

Returns:

Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.

Parameters:

L – New Lipschitz constant.
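A one-sided gradient penalty can be sketched as follows. This is an illustration under stated assumptions, not the class's actual code: the gradient is evaluated at random convex combinations of the P- and Q-samples, the penalty has the form `penalty_weight * mean(max(0, ||grad|| / Lip_const - 1)^2)`, and the discriminator is passed as a plain function `disc_fn` of a single sample.

```python
import jax
import jax.numpy as jnp


def one_sided_gradient_penalty(disc_fn, x, y, key, penalty_weight, lip_const):
    # Random convex combinations of P- and Q-samples (assumed sampling scheme).
    t = jax.random.uniform(key, (x.shape[0], 1))
    z = t * x + (1.0 - t) * y
    # Per-sample gradient of the scalar discriminator output w.r.t. its input.
    grads = jax.vmap(jax.grad(disc_fn))(z)
    norms = jnp.linalg.norm(grads, axis=1)
    # Only gradient norms above the target Lipschitz constant are penalized.
    excess = jnp.maximum(norms / lip_const - 1.0, 0.0)
    return penalty_weight * jnp.mean(excess ** 2)


def disc_fn(z):
    # Hypothetical discriminator with constant gradient (3, 4), i.e. norm 5.
    return jnp.dot(jnp.array([3.0, 4.0]), z)


key = jax.random.PRNGKey(0)
x, y = jnp.ones((8, 2)), jnp.zeros((8, 2))
penalty_violated = one_sided_gradient_penalty(disc_fn, x, y, key, 10.0, 1.0)    # 10 * (5 - 1)^2
penalty_satisfied = one_sided_gradient_penalty(disc_fn, x, y, key, 10.0, 10.0)  # norm below L
```

When the gradient norm stays at or below `lip_const` the penalty is exactly zero, which is what distinguishes the one-sided form from the two-sided one.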

class Divergences_jax.Gradient_Penalty_2Sided(penalty_weight, Lip_const)[source]

Bases: Discriminator_Penalty

Two-sided gradient penalty to enforce a constraint: Lipschitz constant = Lip_const.

evaluate(discriminator, x, y, params, labels=None, dropout_rng=None)[source]

Computes the two-sided gradient penalty to enforce the Lipschitz constant constraint.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Two-sided gradient penalty value.

get_Lip_constant()[source]

Returns the target Lipschitz constant.

Returns:

Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.

Parameters:

L – New Lipschitz constant.
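For comparison, a two-sided gradient penalty can be sketched as below. The interpolation scheme, the exact form `penalty_weight * mean((||grad|| / Lip_const - 1)^2)`, and the explicit `key` argument (which the documented `evaluate` signature of this class does not list) are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def two_sided_gradient_penalty(disc_fn, x, y, key, penalty_weight, lip_const):
    # Evaluate gradients at random points on the segments between x and y (assumed).
    t = jax.random.uniform(key, (x.shape[0], 1))
    z = t * x + (1.0 - t) * y
    grads = jax.vmap(jax.grad(disc_fn))(z)
    norms = jnp.linalg.norm(grads, axis=1)
    # Any deviation of the gradient norm from lip_const is penalized, above or below.
    return penalty_weight * jnp.mean((norms / lip_const - 1.0) ** 2)


def disc_fn(z):
    # Hypothetical discriminator with constant gradient (3, 4), i.e. norm 5.
    return jnp.dot(jnp.array([3.0, 4.0]), z)


key = jax.random.PRNGKey(0)
x, y = jnp.ones((8, 2)), jnp.zeros((8, 2))
too_steep = two_sided_gradient_penalty(disc_fn, x, y, key, 10.0, 1.0)   # 10 * (5/1 - 1)^2
too_flat = two_sided_gradient_penalty(disc_fn, x, y, key, 10.0, 10.0)   # 10 * (5/10 - 1)^2
```

Unlike a one-sided penalty, the `too_flat` case is nonzero: a gradient norm below the target is penalized as well.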

class Divergences_jax.IPM(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Divergence

Integral Probability Metric (IPM) class, a subclass of Divergence. Evaluates the IPM between distributions P and Q using a variational formula.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for IPM.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for IPM when applied to a generator model.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
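The IPM variational objective is the difference of discriminator means, E_P[g] - E_Q[g], maximized over a function class. A minimal sketch on precomputed discriminator outputs follows; the function name `ipm_objective` and the choice g(x) = x are illustrative assumptions, and how the class handles labels, dropout, and updated variables is not shown.

```python
import jax.numpy as jnp


def ipm_objective(g_x, g_y):
    # E_P[g] - E_Q[g], evaluated on discriminator outputs for x ~ P and y ~ Q.
    return jnp.mean(g_x) - jnp.mean(g_y)


# Example: Q-samples on a grid, P-samples are the same grid shifted by 2.
y = jnp.linspace(-1.0, 1.0, 101)
x = y + 2.0

# With the 1-Lipschitz discriminator g(z) = z the objective recovers the shift.
ipm_value = ipm_objective(x, y)
```

The function class (for example, Lipschitz-constrained networks enforced through a discriminator penalty) determines which metric the supremum corresponds to.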

class Divergences_jax.Jensen_Shannon_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: f_Divergence

Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y) - (y + 1) * log((y + 1) / 2).

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

final_layer_activation(y)[source]

Final layer activation function for Jensen-Shannon divergence.

Parameters:

y – Input to the activation function.

Returns:

Activated output.
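For the f given above, the Legendre transform has the closed form f*(t) = -log(2 - exp(t)), finite only for t < log(2). The sketch below checks this against a brute-force supremum over a grid. The activation `log(2) - softplus(-v)` is one common way to keep discriminator outputs in that domain; it is an assumption here, not necessarily what `final_layer_activation` implements.

```python
import jax
import jax.numpy as jnp


def f_js(y):
    # f(y) = y*log(y) - (y+1)*log((y+1)/2).
    return y * jnp.log(y) - (y + 1.0) * jnp.log((y + 1.0) / 2.0)


def f_star_js(t):
    # Closed-form Legendre transform, finite only for t < log(2).
    return -jnp.log(2.0 - jnp.exp(t))


def final_activation_js(v):
    # Assumed choice: maps any real v into (-inf, log 2).
    return jnp.log(2.0) - jax.nn.softplus(-v)


def legendre_on_grid(f, t, grid):
    # Brute-force sup_y { t*y - f(y) } over a finite grid of y > 0.
    return jnp.max(t * grid - f(grid))


grid = jnp.linspace(1e-3, 50.0, 200001)
t = 0.3
closed_form = f_star_js(t)
numeric = legendre_on_grid(f_js, t, grid)
```

The two values agree up to grid and floating-point error, which is a quick way to sanity-check any hand-derived `f_star`.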

class Divergences_jax.KLD_DV(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Divergence

KL Divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for KL divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of KL divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for KL divergence applied to a generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
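The Donsker-Varadhan formula expresses KL(P||Q) as the supremum over g of E_P[g] - log E_Q[exp(g)]. The sketch below evaluates that objective on discriminator outputs and checks it on a Gaussian pair where the optimal g is known. The function name `dv_objective` and the test distributions are illustrative assumptions, not part of the module.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def dv_objective(g_x, g_y):
    # E_P[g] - log E_Q[exp(g)], with the log-mean-exp computed stably.
    return jnp.mean(g_x) - (logsumexp(g_y) - jnp.log(g_y.shape[0]))


mu = 1.0
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = mu + jax.random.normal(key_x, (100_000,))  # x ~ P = N(mu, 1)
y = jax.random.normal(key_y, (100_000,))       # y ~ Q = N(0, 1)


def g(z):
    # Optimal discriminator for this Gaussian pair: log dP/dQ.
    return mu * z - 0.5 * mu ** 2


kl_estimate = dv_objective(g(x), g(y))              # close to mu^2 / 2
shifted_estimate = dv_objective(g(x) + 3.0, g(y) + 3.0)
```

The objective is invariant to adding a constant to g, so `shifted_estimate` matches `kl_estimate` up to rounding.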

class Divergences_jax.KLD_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: f_Divergence

Kullback-Leibler (KL) Divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y).

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.
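For f(y) = y * log(y) the Legendre transform is f*(t) = exp(t - 1). As a hedged sanity check (a sketch, not the class's code), the snippet below verifies the Fenchel equality f*(t) = t*y - f(y) at the maximizer y, which equals the derivative of f* at t.

```python
import jax
import jax.numpy as jnp


def f_kl(y):
    return y * jnp.log(y)


def f_star_kl(t):
    # Closed-form Legendre transform of y*log(y).
    return jnp.exp(t - 1.0)


t = 0.5
# The maximizer of t*y - f(y) equals the derivative of f_star at t.
y_star = jax.grad(f_star_kl)(t)
fenchel_value = t * y_star - f_kl(y_star)
# First-order optimality: f'(y_star) should reproduce t.
slope_at_y_star = jax.grad(f_kl)(y_star)
```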

class Divergences_jax.Pearson_chi_squared_HCR(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Divergence

Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.

Parameters:
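  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional labels for the input data.

  • dropout_rng – Optional dropout key for stochasticity.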

Returns:

Pearson chi-squared divergence loss.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the generator’s objective based on Pearson chi-squared divergence.
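The Hammersley-Chapman-Robbins bound gives chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g] for any test function g, with equality at g = dP/dQ. The sketch below evaluates that ratio on a Gaussian pair; the function name `hcr_objective` and the test setup are illustrative assumptions, and the class's exact objective may differ in detail.

```python
import jax
import jax.numpy as jnp


def hcr_objective(g_x, g_y):
    # (E_P[g] - E_Q[g])^2 / Var_Q[g] on discriminator outputs.
    return (jnp.mean(g_x) - jnp.mean(g_y)) ** 2 / jnp.var(g_y)


mu = 0.5
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = mu + jax.random.normal(key_x, (200_000,))  # x ~ P = N(mu, 1)
y = jax.random.normal(key_y, (200_000,))       # y ~ Q = N(0, 1)


def g(z):
    # Likelihood ratio dP/dQ, which attains the bound for this pair.
    return jnp.exp(mu * z - 0.5 * mu ** 2)


chi2_estimate = hcr_objective(g(x), g(y))
chi2_exact = jnp.exp(mu ** 2) - 1.0  # closed form for two unit-variance Gaussians
```

The ratio is unchanged when g is shifted or rescaled, so only the shape of the discriminator output matters.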

class Divergences_jax.Pearson_chi_squared_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: f_Divergence

Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (y - 1)^2.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.
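For f(y) = (y - 1)^2 the Legendre transform is f*(t) = t^2/4 + t (when y is restricted to y >= 0, this form holds for t >= -2). The following sketch, which is illustrative rather than the class's code, checks the Fenchel equality at the maximizer.

```python
import jax


def f_pearson(y):
    return (y - 1.0) ** 2


def f_star_pearson(t):
    # Closed-form Legendre transform of (y - 1)^2.
    return t ** 2 / 4.0 + t


t = 0.5
# The maximizer of t*y - f(y) equals the derivative of f_star at t: y = 1 + t/2.
y_star = jax.grad(f_star_pearson)(t)
fenchel_value = t * y_star - f_pearson(y_star)
```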

class Divergences_jax.Renyi_Divergence(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Divergence

Renyi divergence class, a subclass of Divergence. R_alpha(P||Q), x~P, y~Q.

get_order()[source]

Returns the order of the Renyi divergence.

Returns:

Alpha order.

set_order(alpha)[source]

Sets the order of the Renyi divergence.

Parameters:

alpha – New alpha order.

class Divergences_jax.Renyi_Divergence_CC(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)[source]

Bases: Renyi_Divergence

Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of Renyi divergence using the convex-conjugate variational formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of Renyi divergence for the generator using the convex-conjugate variational formula.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

final_layer_activation(y)[source]

Final layer activation function to enforce positive values.

Parameters:

y – Output of the discriminator.

Returns:

Activated output, ensuring positivity based on the final activation function.

class Divergences_jax.Renyi_Divergence_CC_rescaled(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)[source]

Bases: Renyi_Divergence_CC

Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of the rescaled Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of rescaled Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for the generator of the rescaled Renyi divergence.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

final_layer_activation(y)[source]

Final layer activation function to enforce positivity, scaled by alpha.

Parameters:

y – Output of the discriminator.

Returns:

Activated output, scaled by the alpha parameter.

class Divergences_jax.Renyi_Divergence_DV(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Renyi_Divergence

Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of Renyi divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of Renyi divergence for the generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.
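One form of the Renyi-Donsker-Varadhan objective found in the literature is (1/(alpha-1)) * log E_P[exp((alpha-1) g)] - (1/alpha) * log E_Q[exp(alpha g)]. The sketch below assumes that form, and with it the normalization R_alpha = 1/(alpha (alpha-1)) * log E_Q[(dP/dQ)^alpha]; whether the class uses exactly this convention and argument order is not confirmed by this reference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def renyi_dv_objective(g_x, g_y, alpha):
    # Assumed form: 1/(alpha-1) * log E_P[exp((alpha-1) g)] - 1/alpha * log E_Q[exp(alpha g)].
    log_mean_p = logsumexp((alpha - 1.0) * g_x) - jnp.log(g_x.shape[0])
    log_mean_q = logsumexp(alpha * g_y) - jnp.log(g_y.shape[0])
    return log_mean_p / (alpha - 1.0) - log_mean_q / alpha


mu, alpha = 0.5, 2.0
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = mu + jax.random.normal(key_x, (100_000,))  # x ~ P = N(mu, 1)
y = jax.random.normal(key_y, (100_000,))       # y ~ Q = N(0, 1)


def g(z):
    # log dP/dQ for this Gaussian pair.
    return mu * z - 0.5 * mu ** 2


renyi_estimate = renyi_dv_objective(g(x), g(y), alpha)  # close to mu^2 / 2 under this normalization
```

Under the assumed normalization the value for two unit-variance Gaussians is mu^2 / 2 for every alpha; multiplying by alpha gives the classical Renyi divergence of order alpha.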

class Divergences_jax.Renyi_Divergence_WCR(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)[source]

Bases: Renyi_Divergence_CC

Rescaled Renyi divergence class as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of the rescaled Renyi divergence in the limit as alpha approaches infinity (worst-case regret divergence).

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of worst-case regret divergence loss and updated variables.

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for the generator of the worst-case regret divergence.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

class Divergences_jax.alpha_Divergence_LT(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]

Bases: f_Divergence

Alpha-divergence class based on the Legendre transform. D_f_alpha(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f_alpha based on the alpha value.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

get_order()[source]

Returns the order of the alpha-divergence.

Returns:

Alpha order.

set_order(alpha)[source]

Sets the order of the alpha-divergence.

Parameters:

alpha – New alpha order.
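The reference does not spell out f_alpha. Assuming the common choice f_alpha(y) = (y^alpha - 1) / (alpha (alpha - 1)) with alpha > 1, its Legendre transform over y >= 0 is f*(t) = (alpha-1)^(alpha/(alpha-1)) / alpha * max(t, 0)^(alpha/(alpha-1)) + 1/(alpha (alpha-1)). The sketch below checks this closed form against a brute-force supremum; the definition of f_alpha and the restriction to alpha > 1 are assumptions.

```python
import jax.numpy as jnp


def f_alpha(y, alpha):
    # Assumed generator: (y^alpha - 1) / (alpha * (alpha - 1)), alpha > 1.
    return (y ** alpha - 1.0) / (alpha * (alpha - 1.0))


def f_star_alpha(t, alpha):
    # Closed form of sup_{y >= 0} { t*y - f_alpha(y) } for alpha > 1.
    exponent = alpha / (alpha - 1.0)
    const = 1.0 / (alpha * (alpha - 1.0))
    return (alpha - 1.0) ** exponent / alpha * jnp.maximum(t, 0.0) ** exponent + const


alpha = 3.0
grid = jnp.linspace(0.0, 20.0, 200001)
# Brute-force suprema for one positive and one negative argument.
numeric_pos = jnp.max(0.5 * grid - f_alpha(grid, alpha))
numeric_neg = jnp.max(-1.0 * grid - f_alpha(grid, alpha))
```

For t <= 0 the supremum is attained at y = 0, which is why the transform is constant on the negative half-line.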

class Divergences_jax.f_Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: Divergence

f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses need to implement the Legendre transform of f (f_star).

eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula of f-divergence, D_f(P||Q).

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of divergence loss and updated variables.
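The Legendre-transform variational formula for an f-divergence is D_f(P||Q) = sup over g of E_P[g] - E_Q[f*(g)]. A minimal sketch of that objective on discriminator outputs follows, using the KL case f*(t) = exp(t - 1) as the example. The function names and the Gaussian test pair are illustrative assumptions, and the final-layer activation and updated variables handled by the class are omitted.

```python
import jax
import jax.numpy as jnp


def f_divergence_objective(g_x, g_y, f_star):
    # E_P[g] - E_Q[f*(g)] on discriminator outputs g_x = g(x), g_y = g(y).
    return jnp.mean(g_x) - jnp.mean(f_star(g_y))


def f_star_kl(t):
    # Legendre transform of f(y) = y*log(y), used here as the example f.
    return jnp.exp(t - 1.0)


mu = 1.0
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = mu + jax.random.normal(key_x, (100_000,))  # x ~ P = N(mu, 1)
y = jax.random.normal(key_y, (100_000,))       # y ~ Q = N(0, 1)


def g(z):
    # Optimal discriminator for this pair: f'(dP/dQ) = 1 + log dP/dQ.
    return 1.0 + mu * z - 0.5 * mu ** 2


kl_estimate = f_divergence_objective(g(x), g(y), f_star_kl)  # close to mu^2 / 2
```

Swapping in a different `f_star` changes which f-divergence is estimated without touching the rest of the objective, which matches the subclassing design described above.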

eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)[source]

Evaluates the variational formula for f-divergence when applied to a generator.

Parameters:
  • x – Generated samples.

  • params – Discriminator parameters.

  • vars – Additional discriminator variables.

  • labels – Optional input labels.

  • dropout_rng – Optional dropout key for stochasticity.

Returns:

Tuple of generator loss and updated variables.

f_star(y)[source]

Placeholder for the Legendre transform of the function f. Should be implemented by subclasses.

Parameters:

y – Input to the Legendre transform.

Returns:

None.

final_layer_activation(y)[source]

Final activation function applied to the output of the discriminator.

Parameters:

y – Output of the discriminator.

Returns:

Activated output.

class Divergences_jax.squared_Hellinger_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]

Bases: f_Divergence

Squared Hellinger distance class based on the Legendre transform. H^2(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (sqrt(y) - 1)^2.

Parameters:

y – Input to the Legendre transform.

Returns:

Transformed value.

final_layer_activation(y)[source]

Final layer activation for squared Hellinger distance.

Parameters:

y – Input to the activation function.

Returns:

Activated output.
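For f(y) = (sqrt(y) - 1)^2 the Legendre transform is f*(t) = t / (1 - t), finite only for t < 1. The sketch below checks this against a brute-force supremum. The activation `1 - exp(-v)`, which keeps discriminator outputs below 1, is a common choice assumed here for illustration and may differ from what `final_layer_activation` implements.

```python
import jax.numpy as jnp


def f_hellinger(y):
    return (jnp.sqrt(y) - 1.0) ** 2


def f_star_hellinger(t):
    # Closed form of sup_{y >= 0} { t*y - f(y) }, finite only for t < 1.
    return t / (1.0 - t)


def final_activation_hellinger(v):
    # Assumed choice: maps any real v into (-inf, 1).
    return 1.0 - jnp.exp(-v)


grid = jnp.linspace(0.0, 50.0, 200001)
t = 0.5
closed_form = f_star_hellinger(t)
numeric = jnp.max(t * grid - f_hellinger(grid))  # brute-force supremum over the grid
```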