Divergences_tf module
- class Divergences_tf.Discriminator_Penalty(penalty_weight)
Bases: object
Base class for discriminator penalties. A penalty term modifies the divergence objective functional during training, which allows discriminator constraints to be enforced (approximately).
- evaluate(discriminator, x, y, labels=None)
Evaluates the penalty term. Should be overridden by subclasses.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
None. (Subclasses should implement penalty evaluation.)
- class Divergences_tf.Divergence(*args, **kwargs)
Bases: Model
Parent class for estimating a divergence D(P||Q) between random variables x~P and y~Q. Defines the parameters and methods shared by all divergence classes.
- discriminate(x, labels=None)
Discriminates between samples from distributions P and Q.
- Parameters:
x – Input data to discriminate.
labels – Optional labels for the input data.
- Returns:
Discriminator output.
- discriminator_loss(x, y, labels=None)
Computes the discriminator loss.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Discriminator loss.
- estimate(x, y, labels=None)
Estimates the divergence measure.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Estimated divergence loss.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for the divergence measure. Should be implemented by subclasses.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
None (to be overridden by subclasses).
- generator_loss(x, labels=None)
Computes the generator loss.
- Parameters:
x – Generated samples.
labels – Optional labels for the input data.
- Returns:
Generator loss.
- train(data_P, data_Q, labels=None, save_estimates=True)
Trains the model for a number of epochs.
- Parameters:
data_P – Data samples from distribution P.
data_Q – Data samples from distribution Q.
labels – Optional labels for the input data.
save_estimates – Whether to save divergence estimates.
- Returns:
A list of divergence estimates for each epoch.
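The split between the parent class (shared plumbing) and its subclasses (the variational formula) can be illustrated with a small pure-Python sketch. Everything in it is hypothetical: `DivergenceSketch`, `MeanDifferenceSketch`, and the sign convention in `discriminator_loss` (negated objective plus an optional penalty) are assumptions made for illustration, not the module's actual implementation, and a plain callable stands in for the Keras discriminator model.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for a discriminator model: any callable float -> float.
Discriminator = Callable[[float], float]


class DivergenceSketch:
    """Illustrative parent class: shared plumbing, formula left to subclasses."""

    def __init__(self, discriminator: Discriminator, penalty: Optional[object] = None):
        self.discriminator = discriminator
        # Optional object exposing evaluate(discriminator, x, y) -> float (assumed interface).
        self.penalty = penalty

    def discriminate(self, x: Sequence[float]) -> List[float]:
        # Apply the discriminator to every sample.
        return [self.discriminator(v) for v in x]

    def eval_var_formula(self, x: Sequence[float], y: Sequence[float]) -> float:
        raise NotImplementedError("subclasses implement the variational formula")

    def estimate(self, x: Sequence[float], y: Sequence[float]) -> float:
        # The estimate is the variational objective at the current discriminator.
        return self.eval_var_formula(x, y)

    def discriminator_loss(self, x: Sequence[float], y: Sequence[float]) -> float:
        # The discriminator maximizes the objective, so an optimizer would minimize
        # its negative; a penalty, if present, is added on top (assumed convention).
        loss = -self.eval_var_formula(x, y)
        if self.penalty is not None:
            loss += self.penalty.evaluate(self.discriminator, x, y)
        return loss


class MeanDifferenceSketch(DivergenceSketch):
    """Toy IPM-style subclass: E_P[f] - E_Q[f]."""

    def eval_var_formula(self, x, y):
        fx, fy = self.discriminate(x), self.discriminate(y)
        return sum(fx) / len(fx) - sum(fy) / len(fy)


div = MeanDifferenceSketch(discriminator=lambda v: v)
print(div.estimate([1.0, 2.0, 3.0], [0.0, 0.0, 0.0]))            # 2.0
print(div.discriminator_loss([1.0, 2.0, 3.0], [0.0, 0.0, 0.0]))  # -2.0
```

Keeping the formula in one overridable method means `estimate` and the loss functions never need to change when a new divergence is added.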
- class Divergences_tf.Gradient_Penalty_1Sided(penalty_weight, Lip_const)
Bases: Discriminator_Penalty
One-sided gradient penalty class: penalizes the discriminator only where its gradient exceeds Lip_const, enforcing a Lipschitz constant <= Lip_const.
- evaluate(discriminator, x, y, labels=None)
Computes the one-sided gradient penalty to enforce the Lipschitz constant <= Lip_const.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
One-sided gradient penalty value.
- class Divergences_tf.Gradient_Penalty_2Sided(penalty_weight, Lip_const)
Bases: Discriminator_Penalty
Two-sided gradient penalty class: penalizes any deviation of the discriminator's gradient from Lip_const, enforcing a Lipschitz constant = Lip_const.
- evaluate(discriminator, x, y, labels=None)
Computes the two-sided gradient penalty to enforce the Lipschitz constant = Lip_const.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Two-sided gradient penalty value.
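The difference between the one-sided and two-sided penalties can be sketched in pure Python for a scalar discriminator. This assumes one common form, penalty_weight * mean(phi(|f'(z)| / Lip_const - 1)^2) evaluated at random interpolates z between paired samples of P and Q, with phi = max(., 0) for the one-sided variant; the module's exact expression (for example, whether it uses squared gradient norms) is not stated here. The helper names are hypothetical, and a finite difference replaces the automatic differentiation a TensorFlow implementation would use.

```python
import random
from typing import Callable, List, Sequence


def _derivative(f: Callable[[float], float], z: float, h: float = 1e-5) -> float:
    # Central finite difference; a real implementation would use automatic differentiation.
    return (f(z + h) - f(z - h)) / (2.0 * h)


def _interpolate(x: Sequence[float], y: Sequence[float], rng: random.Random) -> List[float]:
    # One random point on the segment between each paired sample x_i ~ P and y_i ~ Q.
    points = []
    for xi, yi in zip(x, y):
        t = rng.random()
        points.append(t * xi + (1.0 - t) * yi)
    return points


def gradient_penalty(f, x, y, penalty_weight, lip_const, one_sided=True, seed=0):
    """penalty_weight * mean(phi(|f'(z)| / lip_const - 1)**2) over interpolates z,
    with phi = max(., 0) when one_sided is True and the identity otherwise."""
    rng = random.Random(seed)
    terms = []
    for z in _interpolate(x, y, rng):
        excess = abs(_derivative(f, z)) / lip_const - 1.0
        if one_sided:
            excess = max(excess, 0.0)  # slopes below lip_const are not penalized
        terms.append(excess ** 2)
    return penalty_weight * sum(terms) / len(terms)


def steep(v: float) -> float:
    return 3.0 * v  # slope 3, above lip_const = 1


def flat(v: float) -> float:
    return 0.5 * v  # slope 0.5, below lip_const = 1


x = [0.0, 1.0, 2.0, -1.5]
y = [1.0, -1.0, 0.5, 3.0]
print(gradient_penalty(steep, x, y, penalty_weight=10.0, lip_const=1.0))                  # about 40
print(gradient_penalty(flat, x, y, penalty_weight=10.0, lip_const=1.0))                   # 0.0
print(gradient_penalty(flat, x, y, penalty_weight=10.0, lip_const=1.0, one_sided=False))  # about 2.5
```

A too-flat discriminator is free under the one-sided penalty but still penalized under the two-sided one.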
- class Divergences_tf.IPM(*args, **kwargs)
Bases: Divergence
Integral Probability Metric (IPM) class for evaluating IPMs.
- class Divergences_tf.Jensen_Shannon_LT(*args, **kwargs)
Bases: f_Divergence
Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.
- class Divergences_tf.KLD_DV(*args, **kwargs)
Bases: Divergence
KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.
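The Donsker-Varadhan formula states KL(P||Q) = sup_g { E_P[g] - log E_Q[exp(g)] }. The pure-Python sketch below evaluates that objective for a fixed function g in place of a trained discriminator; `kl_dv` is a hypothetical helper, not the class's method. For P = N(1, 1) and Q = N(0, 1), g(v) = log(dP/dQ)(v) = v - 1/2 attains the supremum and KL(P||Q) = 1/2.

```python
import math
import random


def kl_dv(g, x, y):
    """Donsker-Varadhan objective E_P[g] - log E_Q[exp(g)], a lower bound on KL(P||Q)."""
    mean_p = sum(g(v) for v in x) / len(x)
    # log-mean-exp, shifted by the maximum for numerical stability
    gy = [g(v) for v in y]
    m = max(gy)
    log_mean_exp_q = m + math.log(sum(math.exp(v - m) for v in gy) / len(gy))
    return mean_p - log_mean_exp_q


rng = random.Random(0)
n = 50_000
x = [rng.gauss(1.0, 1.0) for _ in range(n)]  # samples from P = N(1, 1)
y = [rng.gauss(0.0, 1.0) for _ in range(n)]  # samples from Q = N(0, 1)

print(kl_dv(lambda v: v - 0.5, x, y))  # close to KL(P||Q) = 0.5
```

The objective is unchanged when a constant is added to g, so the discriminator only needs to learn the log density ratio up to a shift.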
- class Divergences_tf.KLD_LT(*args, **kwargs)
Bases: f_Divergence
Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.
- class Divergences_tf.Pearson_chi_squared_HCR(*args, **kwargs)
Bases: Divergence
Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Pearson chi-squared divergence loss.
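The Hammersley-Chapman-Robbins bound gives chi^2(P||Q) = sup_g (E_P[g] - E_Q[g])^2 / Var_Q[g] over non-constant g. The sketch below evaluates that ratio for a fixed g in pure Python; `chi2_hcr` is a hypothetical helper, and the library's own variational expression may be parameterized differently. For P = N(1, 1) and Q = N(0, 1), chi^2(P||Q) = e - 1 (about 1.718), and the supremum is reached at g proportional to dP/dQ, here exp(v).

```python
import random


def chi2_hcr(g, x, y):
    """HCR objective (E_P[g] - E_Q[g])**2 / Var_Q[g]; its supremum over
    non-constant g equals chi^2(P||Q)."""
    gx = [g(v) for v in x]
    gy = [g(v) for v in y]
    mean_p = sum(gx) / len(gx)
    mean_q = sum(gy) / len(gy)
    var_q = sum((v - mean_q) ** 2 for v in gy) / len(gy)
    return (mean_p - mean_q) ** 2 / var_q


rng = random.Random(0)
n = 50_000
x = [rng.gauss(1.0, 1.0) for _ in range(n)]  # samples from P = N(1, 1)
y = [rng.gauss(0.0, 1.0) for _ in range(n)]  # samples from Q = N(0, 1)

# The suboptimal choice g(v) = v gives about 1.0, a lower bound on e - 1.
print(chi2_hcr(lambda v: v, x, y))
```

Because the ratio is invariant under affine maps of g, the discriminator's output scale and offset do not matter.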
- class Divergences_tf.Pearson_chi_squared_LT(*args, **kwargs)
Bases: f_Divergence
Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.
- class Divergences_tf.Renyi_Divergence(*args, **kwargs)
Bases: Divergence
Parent class for estimators of the Renyi divergence R_alpha(P||Q), x~P, y~Q.
- class Divergences_tf.Renyi_Divergence_CC(*args, **kwargs)
Bases: Renyi_Divergence
Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for Renyi divergence based on the convex-conjugate formula.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Renyi divergence loss.
- class Divergences_tf.Renyi_Divergence_CC_rescaled(*args, **kwargs)
Bases: Renyi_Divergence_CC
Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for the rescaled Renyi divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Rescaled Renyi divergence loss.
- class Divergences_tf.Renyi_Divergence_DV(*args, **kwargs)
Bases: Renyi_Divergence
Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.
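A Renyi-Donsker-Varadhan formula can be sketched as R_alpha(P||Q) = sup_g { 1/(alpha-1) log E_P[exp((alpha-1) g)] - 1/alpha log E_Q[exp(alpha g)] }, valid for alpha > 0, alpha != 1. This assumes the normalization R_alpha(P||Q) = 1/(alpha(alpha-1)) log E_Q[(dP/dQ)^alpha]; with the more common 1/(alpha-1) normalization the value differs by a factor of alpha, and whether the module uses this convention is an assumption here. `renyi_dv` is a hypothetical helper and g is a fixed function rather than a trained discriminator. Under this normalization, R_alpha(N(1,1)||N(0,1)) = 1/2 for every alpha, attained at g(v) = log(dP/dQ)(v) = v - 1/2.

```python
import math
import random


def _log_mean_exp(values):
    # log(mean(exp(values))), shifted by the maximum for numerical stability
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def renyi_dv(g, x, y, alpha):
    """Renyi-DV objective
        1/(alpha-1) * log E_P[exp((alpha-1) g)] - 1/alpha * log E_Q[exp(alpha g)],
    whose supremum over g is 1/(alpha(alpha-1)) * log E_Q[(dP/dQ)^alpha]."""
    if alpha <= 0.0 or alpha == 1.0:
        raise ValueError("alpha must be positive and different from 1")
    term_p = _log_mean_exp([(alpha - 1.0) * g(v) for v in x]) / (alpha - 1.0)
    term_q = _log_mean_exp([alpha * g(v) for v in y]) / alpha
    return term_p - term_q


rng = random.Random(0)
n = 50_000
x = [rng.gauss(1.0, 1.0) for _ in range(n)]  # samples from P = N(1, 1)
y = [rng.gauss(0.0, 1.0) for _ in range(n)]  # samples from Q = N(0, 1)

print(renyi_dv(lambda v: v - 0.5, x, y, alpha=1.5))  # close to 0.5
```

As alpha approaches 1, the objective reduces to the Donsker-Varadhan objective E_P[g] - log E_Q[exp(g)].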
- class Divergences_tf.Renyi_Divergence_WCR(*args, **kwargs)
Bases: Renyi_Divergence_CC
Renyi divergence class for the limit alpha -> infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for worst-case regret divergence as alpha approaches infinity.
- Parameters:
x (tf.Tensor) – Samples from distribution P.
y (tf.Tensor) – Samples from distribution Q.
labels (tf.Tensor, optional) – Optional labels for the input data.
- Returns:
Worst-case regret divergence loss.
- Return type:
tf.Tensor
- eval_var_formula_gen(X, labels=None)
Evaluates the generator’s objective based on worst-case regret divergence.
- Parameters:
X (tf.Tensor) – Samples from distribution Q.
labels (tf.Tensor, optional) – Optional labels for the input data.
- Returns:
Generator’s loss based on worst-case regret divergence.
- Return type:
tf.Tensor
- class Divergences_tf.alpha_Divergence_LT(*args, **kwargs)
Bases: f_Divergence
Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.
- class Divergences_tf.f_Divergence(*args, **kwargs)
Bases: Divergence
f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses must implement the Legendre transform f_star.
- eval_var_formula(x, y, labels=None)
Evaluates the variational formula for f-divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
f-divergence loss.
- eval_var_formula_gen(x, labels=None)
Evaluates the variational formula for f-divergence applied to the generator.
- Parameters:
x – Generated samples.
labels – Optional labels for the input data.
- Returns:
Generator loss based on f-divergence.
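The Legendre-transform formula behind the f_Divergence subclasses is D_f(P||Q) = sup_g { E_P[g] - E_Q[f_star(g)] }, where D_f(P||Q) = E_Q[f(dP/dQ)] and the supremum is attained at g = f'(dP/dQ). The pure-Python sketch below evaluates the objective for a fixed g. The `F_STAR` table and `f_divergence_lt` are hypothetical; they use textbook generators f, and the module's own normalization of each f (and hence its f_star) may differ.

```python
import math
import random

# Legendre transforms f*(s) for a few textbook generators f (assumed normalizations).
F_STAR = {
    "KL": lambda s: math.exp(s - 1.0),             # f(t) = t log t
    "pearson_chi2": lambda s: s + s * s / 4.0,     # f(t) = (t - 1)^2
    "squared_hellinger": lambda s: s / (1.0 - s),  # f(t) = (sqrt(t) - 1)^2, requires s < 1
}


def f_divergence_lt(g, x, y, f_star):
    """Legendre-transform objective E_P[g] - E_Q[f*(g)]; its supremum over g is D_f(P||Q)."""
    mean_p = sum(g(v) for v in x) / len(x)
    mean_q = sum(f_star(g(v)) for v in y) / len(y)
    return mean_p - mean_q


rng = random.Random(0)
n = 50_000
x = [rng.gauss(1.0, 1.0) for _ in range(n)]  # samples from P = N(1, 1)
y = [rng.gauss(0.0, 1.0) for _ in range(n)]  # samples from Q = N(0, 1)

# For KL, f'(t) = 1 + log t and log(dP/dQ)(v) = v - 0.5, so the optimal g is v + 0.5.
print(f_divergence_lt(lambda v: v + 0.5, x, y, F_STAR["KL"]))  # close to KL(P||Q) = 0.5
```

Unlike the Donsker-Varadhan objective, this one is not shift-invariant in g: the constant 1 in the optimal g = 1 + log(dP/dQ) matters.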
- class Divergences_tf.squared_Hellinger_LT(*args, **kwargs)
Bases: f_Divergence
Squared Hellinger distance class based on the Legendre transform. H(P||Q), x~P, y~Q.