Divergences_jax module
- class Divergences_jax.DataLoader(data, batch_size, shuffle=True)
Bases: object
DataLoader class for loading data and splitting it into mini-batches of size batch_size during training; the data is shuffled when shuffle=True.
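Shuffled mini-batching with an explicit JAX PRNG key can be sketched as follows. This is a minimal sketch that assumes batches are taken along the first axis and that a smaller final batch is kept; `iterate_batches` is a hypothetical helper, not the DataLoader implementation.

```python
import jax
import jax.numpy as jnp


def iterate_batches(data, batch_size, key, shuffle=True):
    """Hypothetical helper (not the module's API): yield mini-batches along axis 0."""
    n = data.shape[0]
    # Draw a random permutation of the sample indices when shuffling is enabled.
    idx = jax.random.permutation(key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        # If n is not a multiple of batch_size, the last batch is smaller.
        yield data[idx[start:start + batch_size]]


data = jnp.arange(10.0).reshape(10, 1)
batches = list(iterate_batches(data, batch_size=4, key=jax.random.PRNGKey(0)))
print([b.shape for b in batches])  # [(4, 1), (4, 1), (2, 1)]
```

Passing a fresh key each epoch (for example one obtained with `jax.random.split`) gives a different sample order per epoch.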
- class Divergences_jax.Discriminator_Penalty(penalty_weight)
Bases: object
Base class for penalties applied to the discriminator during training. Subclasses implement discriminator constraints that regularize the divergence objective.
- evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)
Evaluates the penalty term. Should be overridden by subclasses.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
batch_stats – Batch statistics of the discriminator (e.g., batch-normalization running averages).
key – Random key for JAX RNG.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
None. (Subclasses should implement specific penalty evaluations.)
- class Divergences_jax.Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: object
Base class for divergence measures D(P||Q) between random variables x~P and y~Q. It defines the parameters and methods shared by the different divergence measures.
- discriminate(x, params, vars, labels=None, dropout_rng=None)
Applies the discriminator model to the samples x, which may come from either P or Q.
- Parameters:
x – Input data to be discriminated.
params – Parameters of the discriminator model.
vars – Additional variables such as batch statistics.
labels – Optional labels for the input data.
dropout_rng – Optional dropout key for stochasticity in dropout layers.
- Returns:
Tuple of the discriminator output and optional updated batch statistics.
- discriminator_loss(x, y, params, vars, labels=None, dropout_rng=None)
Computes the loss for the discriminator model.
- Parameters:
x – Samples from P.
y – Samples from Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of discriminator loss and updated variables.
- estimate(x, y, params, vars, labels=None, dropout_rng=None)
Estimates the divergence between P and Q.
- Parameters:
x – Samples from P.
y – Samples from Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of estimated divergence and updated variables.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Placeholder method for evaluating the variational formula of a specific divergence. Should be overridden by subclasses.
- Parameters:
x – Samples from P.
y – Samples from Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
None.
- gen_train_step(gen_state, disc_state, disc_vars, gen_vars, key, z, labels=None, dropout_rng=None)
Performs a single training step for the generator.
- Parameters:
gen_state – Generator optimizer state.
disc_state – Discriminator optimizer state.
disc_vars – Discriminator variables.
gen_vars – Generator variables.
key – Random key for JAX RNG.
z – Latent input to the generator.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Updated generator state and generator loss.
- generator_loss(x, params, vars, labels=None, dropout_rng=None)
Computes the loss for the generator model.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
- set_discriminator(discriminator)
Sets a new discriminator model.
- Parameters:
discriminator – New discriminator model.
- set_no_epochs(epochs)
Sets the number of training epochs.
- Parameters:
epochs – New number of epochs.
- train(data_P, data_Q, state, vars, save_estimates=True, labels=None, dropout_rng=None)
Trains the discriminator for the configured number of epochs (see set_no_epochs).
- Parameters:
data_P – Data samples from distribution P.
data_Q – Data samples from distribution Q.
state – Discriminator optimizer state.
vars – Discriminator variables.
save_estimates – Whether to save divergence estimates.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of estimated divergences and losses for each epoch.
- train_step(x, y, state, vars, key, labels=None, dropout_rng=None)
Performs a single training step for the discriminator.
- Parameters:
x – Samples from P.
y – Samples from Q.
state – Optimizer state.
vars – Additional discriminator variables.
key – Random key for JAX RNG.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Updated state and loss value for the current step.
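A single discriminator update can be sketched as a gradient step that maximizes a variational lower bound, i.e. minimizes its negative as the loss. This is a sketch under stated assumptions: a hypothetical linear discriminator, plain SGD in place of an optimizer state, no discriminator penalty term, and the KL bound in Legendre-transform form, E_P[g] - E_Q[exp(g - 1)]. The names `disc`, `kl_lt_objective` and `train_step` are illustrative, not this module's code.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # plain SGD step size (assumption; the module uses an optimizer state)


def disc(params, x):
    # Hypothetical linear discriminator: one score g(x) per sample.
    return x @ params["w"] + params["b"]


def kl_lt_objective(params, x, y):
    # KL(P||Q) >= E_P[g] - E_Q[exp(g - 1)] for any g (Legendre-transform bound).
    return jnp.mean(disc(params, x)) - jnp.mean(jnp.exp(disc(params, y) - 1.0))


@jax.jit
def train_step(params, x, y):
    # The loss is the negative bound, so descending the loss ascends the bound.
    loss, grads = jax.value_and_grad(lambda p: -kl_lt_objective(p, x, y))(params)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = 1.0 + jax.random.normal(kx, (2000, 1))  # x ~ P = N(1, 1)
y = jax.random.normal(ky, (2000, 1))        # y ~ Q = N(0, 1), so KL(P||Q) = 0.5
params = {"w": jnp.zeros(1), "b": jnp.zeros(())}
for _ in range(300):
    params, loss = train_step(params, x, y)
print(float(kl_lt_objective(params, x, y)))  # close to 0.5 up to sampling error
```

The Gaussian pair is chosen because the optimal discriminator 1 + log(dP/dQ)(x) = x + 0.5 is linear, so the linear model can actually attain the bound.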
- class Divergences_jax.Gradient_Penalty_1Sided(penalty_weight, Lip_const)
Bases: Discriminator_Penalty
One-sided gradient penalty that enforces the constraint Lipschitz constant <= Lip_const: only gradient norms above Lip_const are penalized.
- evaluate(discriminator, x, y, params, batch_stats, key, labels=None, dropout_rng=None)
Computes the one-sided gradient penalty to enforce the Lipschitz constant constraint.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
batch_stats – Additional statistics for the batch.
key – Random key for JAX RNG.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
One-sided gradient penalty value.
- class Divergences_jax.Gradient_Penalty_2Sided(penalty_weight, Lip_const)
Bases: Discriminator_Penalty
Two-sided gradient penalty that enforces the constraint Lipschitz constant = Lip_const: deviations of the gradient norm from Lip_const in either direction are penalized.
- evaluate(discriminator, x, y, params, labels=None, dropout_rng=None)
Computes the two-sided gradient penalty to enforce the Lipschitz constant constraint.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Two-sided gradient penalty value.
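Both gradient penalties can be sketched in one function. The sketch assumes the common WGAN-GP convention of evaluating input gradients at random interpolates between paired samples of P and Q and squaring the violation; the module may choose the evaluation points or the functional form differently. `gradient_penalty` and the per-sample scalar discriminator `disc_fn` are hypothetical.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(disc_fn, x, y, key, penalty_weight, lip_const, one_sided=True):
    """Hypothetical gradient penalty; disc_fn maps one sample to a scalar."""
    # Random interpolates z = t*x + (1 - t)*y, one t in [0, 1) per sample pair.
    t = jax.random.uniform(key, (x.shape[0],) + (1,) * (x.ndim - 1))
    z = t * x + (1.0 - t) * y
    # Per-sample input gradients of the discriminator, flattened to vectors.
    grads = jax.vmap(jax.grad(disc_fn))(z).reshape(x.shape[0], -1)
    violation = jnp.linalg.norm(grads, axis=1) - lip_const
    if one_sided:
        # One-sided: only gradient norms above lip_const count as violations.
        violation = jnp.maximum(violation, 0.0)
    # Otherwise (two-sided), deviations in either direction are penalized.
    return penalty_weight * jnp.mean(violation ** 2)


key = jax.random.PRNGKey(0)
x, y = jnp.ones((8, 2)), jnp.zeros((8, 2))
steep = lambda v: 3.0 * jnp.sum(v)  # gradient norm 3*sqrt(2) everywhere
flat = lambda v: 0.5 * jnp.sum(v)   # gradient norm sqrt(2)/2 everywhere
print(gradient_penalty(flat, x, y, key, 1.0, 1.0, one_sided=True))   # 0.0
print(gradient_penalty(flat, x, y, key, 1.0, 1.0, one_sided=False))  # about 0.0858
```

The "flat" discriminator shows the difference between the two constraints: its gradient norm is below Lip_const, which the one-sided penalty accepts and the two-sided penalty does not.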
- class Divergences_jax.IPM(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: Divergence
IPM (integral probability metric) class, a subclass of Divergence. Evaluates the IPM between distributions P and Q using its variational formula.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula for IPM.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula for IPM when applied to a generator model.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
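For a class Γ of test functions, the IPM is sup over g in Γ of E_P[g] - E_Q[g], so for a fixed discriminator g the objective is a difference of sample means. A minimal sketch on discriminator scores follows (`ipm_objective` is a hypothetical name). Restricting Γ to 1-Lipschitz functions gives the Wasserstein-1 distance, which is why IPMs are commonly paired with gradient penalties.

```python
import jax.numpy as jnp


def ipm_objective(gx, gy):
    # gx = g(x) for x ~ P, gy = g(y) for y ~ Q; returns E_P[g] - E_Q[g].
    return jnp.mean(gx) - jnp.mean(gy)


y = jnp.array([0.0, 1.0])
x = y + 1.0  # the empirical P is the empirical Q shifted by 1, so W1 = 1
print(ipm_objective(x, y))  # 1.0, attained by the 1-Lipschitz g(v) = v
```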
- class Divergences_jax.Jensen_Shannon_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: f_Divergence
Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.
- class Divergences_jax.KLD_DV(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: Divergence
KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the Donsker-Varadhan variational formula for the KL divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of KL divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula for KL divergence applied to a generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
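The Donsker-Varadhan formula reads KL(P||Q) = sup over g of E_P[g] - log E_Q[exp(g)], with the supremum attained at g = log(dP/dQ) + c for any constant c. A minimal sketch on discriminator scores follows, checked on a two-point example where the optimal g is known in closed form; `kl_dv_objective` is a hypothetical name, not the module's method.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kl_dv_objective(gx, gy):
    # E_P[g] - log E_Q[exp(g)]; logsumexp(gy) - log(n) is log(mean(exp(gy))), computed stably.
    return jnp.mean(gx) - (logsumexp(gy) - jnp.log(gy.shape[0]))


# Two-point example: P = (0.75, 0.25), Q = (0.5, 0.5) on outcomes {0, 1}.
x = jnp.array([0, 0, 0, 1])          # an exact "sample" of P
y = jnp.array([0, 1])                # an exact "sample" of Q
g = jnp.log(jnp.array([1.5, 0.5]))   # optimal g = log(dP/dQ)
exact_kl = 0.75 * jnp.log(1.5) + 0.25 * jnp.log(0.5)
print(kl_dv_objective(g[x], g[y]), exact_kl)    # both about 0.1308
print(kl_dv_objective(g[x] + 7.0, g[y] + 7.0))  # unchanged by a constant shift
```

The shift invariance is a property of the Donsker-Varadhan form: the constant cancels between the two terms.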
- class Divergences_jax.KLD_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: f_Divergence
Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.
- class Divergences_jax.Pearson_chi_squared_HCR(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: Divergence
Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula for the Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional labels for the input data.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Pearson chi-squared divergence loss.
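In its standard form, the Hammersley-Chapman-Robbins bound reads chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g] for every g with Var_Q[g] > 0, with equality when g is a non-constant affine function of dP/dQ. A minimal sketch on discriminator scores follows; the name `chi2_hcr_objective` is hypothetical, and the module may evaluate an equivalent rearrangement of this bound.

```python
import jax.numpy as jnp


def chi2_hcr_objective(gx, gy):
    # (E_P[g] - E_Q[g])^2 / Var_Q[g]; invariant to affine changes g -> a*g + b (a != 0).
    return (jnp.mean(gx) - jnp.mean(gy)) ** 2 / jnp.var(gy)


# Two-point example: P = (0.75, 0.25), Q = (0.5, 0.5), so chi^2(P||Q) = 0.25.
x = jnp.array([0, 0, 0, 1])
y = jnp.array([0, 1])
g = jnp.array([1.5, 0.5])  # g = dP/dQ attains the bound
print(chi2_hcr_objective(g[x], g[y]))  # 0.25
```

Because the ratio is invariant to affine changes of g, the discriminator's output scale and offset do not affect the estimate.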
- class Divergences_jax.Pearson_chi_squared_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: f_Divergence
Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.
- class Divergences_jax.Renyi_Divergence(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)
Bases: Divergence
Renyi divergence class, a subclass of Divergence. R_alpha(P||Q), x~P, y~Q.
- class Divergences_jax.Renyi_Divergence_CC(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)
Bases: Renyi_Divergence
Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the convex-conjugate variational formula of the Renyi divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of Renyi divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the convex-conjugate variational formula of the Renyi divergence for the generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
- class Divergences_jax.Renyi_Divergence_CC_rescaled(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)
Bases: Renyi_Divergence_CC
Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of the rescaled Renyi divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of rescaled Renyi divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of the rescaled Renyi divergence for the generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
- class Divergences_jax.Renyi_Divergence_DV(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)
Bases: Renyi_Divergence
Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the Renyi-Donsker-Varadhan variational formula of the Renyi divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of Renyi divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of Renyi divergence for the generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
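The Renyi-Donsker-Varadhan formula can be sketched as below. The sketch assumes the normalization R_alpha(P||Q) = 1/(alpha(alpha-1)) log E_Q[(dP/dQ)^alpha], the scaling under which alpha * R_alpha stays finite as alpha approaches infinity. For alpha not in {0, 1} the formula reads R_alpha(P||Q) = sup over g of 1/(alpha-1) log E_P[exp((alpha-1)g)] - 1/alpha log E_Q[exp(alpha g)], attained at g = log(dP/dQ). `log_mean_exp` and `renyi_dv_objective` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_mean_exp(v):
    # log(mean(exp(v))), computed stably.
    return logsumexp(v) - jnp.log(v.shape[0])


def renyi_dv_objective(gx, gy, alpha):
    # 1/(alpha-1) log E_P[exp((alpha-1) g)] - 1/alpha log E_Q[exp(alpha g)]
    return (log_mean_exp((alpha - 1.0) * gx) / (alpha - 1.0)
            - log_mean_exp(alpha * gy) / alpha)


# Two-point example: P = (0.75, 0.25), Q = (0.5, 0.5); dP/dQ = (1.5, 0.5).
x = jnp.array([0, 0, 0, 1])
y = jnp.array([0, 1])
g = jnp.log(jnp.array([1.5, 0.5]))  # optimal g = log(dP/dQ)
alpha = 2.0
exact = jnp.log(0.5 * 1.5 ** alpha + 0.5 * 0.5 ** alpha) / (alpha * (alpha - 1.0))
print(renyi_dv_objective(g[x], g[y], alpha), exact)  # both about 0.1116
```

At g = log(dP/dQ) both expectations equal E_Q[(dP/dQ)^alpha], so the objective collapses to (1/(alpha-1) - 1/alpha) times its logarithm, which is exactly the assumed normalization.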
- class Divergences_jax.Renyi_Divergence_WCR(discriminator, disc_optimizer, alpha, epochs, batch_size, final_act_func, discriminator_penalty=None)
Bases: Renyi_Divergence_CC
Worst-case regret divergence class: the limit of the rescaled Renyi divergence as alpha approaches infinity. D_∞(P||Q), x~P, y~Q.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of the worst-case regret divergence (the rescaled Renyi divergence as alpha approaches infinity).
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of worst-case regret divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of the worst-case regret divergence for the generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
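Assuming the normalization R_alpha(P||Q) = 1/(alpha(alpha-1)) log E_Q[(dP/dQ)^alpha] (an assumption about this module's convention), alpha * R_alpha is the classical Renyi divergence of order alpha, which is nondecreasing in alpha and tends to D_∞(P||Q) = log ess sup dP/dQ. The closed-form check below illustrates this limit for two-point distributions; it evaluates the divergence directly rather than through a variational formula, and `rescaled_renyi` is a hypothetical helper.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def rescaled_renyi(p, q, alpha):
    # alpha * R_alpha(P||Q) = 1/(alpha-1) log sum_i q_i (p_i/q_i)^alpha, computed stably.
    log_ratio = jnp.log(p) - jnp.log(q)
    return logsumexp(jnp.log(q) + alpha * log_ratio) / (alpha - 1.0)


p = jnp.array([0.75, 0.25])
q = jnp.array([0.5, 0.5])
d_inf = jnp.max(jnp.log(p / q))  # log max dP/dQ = log 1.5
for alpha in (2.0, 10.0, 100.0):
    # The middle value increases toward d_inf (about 0.4055) as alpha grows.
    print(alpha, rescaled_renyi(p, q, alpha), d_inf)
```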
- class Divergences_jax.alpha_Divergence_LT(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)
Bases: f_Divergence
Alpha-divergence class based on the Legendre transform. D_f_alpha(P||Q), x~P, y~Q.
- class Divergences_jax.f_Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: Divergence
f-divergence class, the parent class for f-divergence-based measures D_f(P||Q). Subclasses must implement the Legendre transform (convex conjugate) of f, f_star.
- eval_var_formula(x, y, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula of f-divergence, D_f(P||Q).
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of divergence loss and updated variables.
- eval_var_formula_gen(x, params, vars, labels=None, dropout_rng=None)
Evaluates the variational formula for f-divergence when applied to a generator.
- Parameters:
x – Generated samples.
params – Discriminator parameters.
vars – Additional discriminator variables.
labels – Optional input labels.
dropout_rng – Optional dropout key for stochasticity.
- Returns:
Tuple of generator loss and updated variables.
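With the convention D_f(P||Q) = E_Q[f(dP/dQ)], the Legendre-transform formula reads D_f(P||Q) = sup over g of E_P[g] - E_Q[f_star(g)], attained at g = f'(dP/dQ). The sketch below pairs this formula with the convex conjugates of three common generators: KL (f(t) = t log t), Pearson chi-squared (f(t) = (t - 1)^2) and squared Hellinger (f(t) = (sqrt(t) - 1)^2). These generator conventions and the names `F_STAR` and `f_div_objective` are assumptions, not necessarily what the subclasses use.

```python
import jax.numpy as jnp

# Convex conjugates f_star of some generators f (conventions assumed, see above).
F_STAR = {
    "kl": lambda s: jnp.exp(s - 1.0),         # f(t) = t log t
    "pearson": lambda s: s + s ** 2 / 4.0,    # f(t) = (t - 1)^2
    "sq_hellinger": lambda s: s / (1.0 - s),  # f(t) = (sqrt(t) - 1)^2, requires s < 1
}


def f_div_objective(gx, gy, name):
    # Legendre-transform bound: E_P[g] - E_Q[f_star(g)].
    return jnp.mean(gx) - jnp.mean(F_STAR[name](gy))


# Two-point example: P = (0.75, 0.25), Q = (0.5, 0.5), r = dP/dQ = (1.5, 0.5).
x = jnp.array([0, 0, 0, 1])
y = jnp.array([0, 1])
r = jnp.array([1.5, 0.5])
optimal_g = {
    "kl": 1.0 + jnp.log(r),                   # f'(r) for f(t) = t log t
    "pearson": 2.0 * (r - 1.0),               # f'(r) for f(t) = (t - 1)^2
    "sq_hellinger": 1.0 - 1.0 / jnp.sqrt(r),  # f'(r) for f(t) = (sqrt(t) - 1)^2
}
for name, g in optimal_g.items():
    # Matches the exact D_f: KL about 0.1308, Pearson 0.25, squared Hellinger about 0.0681.
    print(name, f_div_objective(g[x], g[y], name))
```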
- class Divergences_jax.squared_Hellinger_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)
Bases: f_Divergence
Squared Hellinger distance class based on the Legendre transform. H^2(P||Q), x~P, y~Q.