Divergences_torch module
- class Divergences_torch.Discriminator_Penalty(penalty_weight)[source]
Bases: object
Discriminator penalty class. Adds a penalty term to the divergence objective functional during training, which allows for the (approximate) implementation of discriminator constraints.
- evaluate(discriminator, x, y, labels=None)[source]
Evaluates the penalty term. Should be overridden by subclasses.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
None. (Subclasses should implement penalty evaluation.)
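Since evaluate is meant to be overridden, a minimal sketch of a custom penalty may help. It does not import Divergences_torch: DiscriminatorPenaltySketch is a stand-in that only mirrors the documented constructor and evaluate signature, the attribute name penalty_weight is an assumption, and OutputMagnitudePenalty is a hypothetical penalty invented for illustration.

```python
import torch


class DiscriminatorPenaltySketch:
    """Stand-in mirroring the documented Discriminator_Penalty interface."""

    def __init__(self, penalty_weight):
        # Assumption: the weight is stored under this attribute name.
        self.penalty_weight = penalty_weight

    def evaluate(self, discriminator, x, y, labels=None):
        # The documented base method returns None.
        return None


class OutputMagnitudePenalty(DiscriminatorPenaltySketch):
    """Hypothetical penalty that discourages large discriminator outputs."""

    def evaluate(self, discriminator, x, y, labels=None):
        # labels is accepted for signature compatibility but unused here.
        d_x = discriminator(x)
        d_y = discriminator(y)
        return self.penalty_weight * (d_x.pow(2).mean() + d_y.pow(2).mean())


penalty = OutputMagnitudePenalty(penalty_weight=0.5)
x = torch.ones(4, 2)   # stand-in samples from P
y = torch.zeros(4, 2)  # stand-in samples from Q
value = penalty.evaluate(lambda z: z.sum(dim=1), x, y)
```

In the real module, a subclass of this kind would presumably be passed as the discriminator_penalty argument of a divergence class.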
- class Divergences_torch.Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Module
Parent class for a divergence D(P||Q) between random variables x~P, y~Q. Defines the parameters and methods common to all divergence classes.
- discriminate(x, labels=None)[source]
Discriminates between samples from distributions P and Q.
- Parameters:
x – Input data to discriminate.
labels – Optional labels for the input data.
- Returns:
Discriminator output.
- discriminator_loss(x, y, labels=None)[source]
Computes the discriminator loss.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Discriminator loss.
- estimate(x, y, labels=None)[source]
Estimates the divergence measure.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Estimated divergence loss.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for the divergence measure. Should be implemented by subclasses.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
None (to be overridden by subclasses).
- generator_loss(x, labels=None)[source]
Computes the generator loss.
- Parameters:
x – Generated samples.
labels – Optional labels for the input data.
- Returns:
Generator loss.
- train(data_P, data_Q, labels=None, save_estimates=True)[source]
Trains the model for a number of epochs.
- Parameters:
data_P – Data samples from distribution P.
data_Q – Data samples from distribution Q.
labels – Optional labels for the input data.
save_estimates – Whether to save divergence estimates.
- Returns:
A list of divergence estimates for each epoch.
- class Divergences_torch.Gradient_Penalty_1Sided(penalty_weight, Lip_const)[source]
Bases: Discriminator_Penalty
One-sided gradient penalty class to enforce the Lipschitz constant <= Lip_const.
- evaluate(discriminator, x, y, labels=None)[source]
Computes the one-sided gradient penalty to enforce the Lipschitz constant <= Lip_const.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
One-sided gradient penalty value.
- class Divergences_torch.Gradient_Penalty_2Sided(penalty_weight, Lip_const)[source]
Bases: Discriminator_Penalty
Two-sided gradient penalty class to enforce the Lipschitz constant = Lip_const.
- evaluate(discriminator, x, y, labels=None)[source]
Computes the two-sided gradient penalty to enforce the Lipschitz constant = Lip_const.
- Parameters:
discriminator – Discriminator model.
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Two-sided gradient penalty value.
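The two gradient penalty classes above can be illustrated with a small sketch. It assumes the common construction in which the discriminator's input gradient is evaluated at random interpolates between paired samples of P and Q. The source does not confirm that Divergences_torch uses exactly this formula, and the function name gradient_penalty and its one_sided flag are illustrative only.

```python
import torch


def gradient_penalty(discriminator, x, y, penalty_weight, lip_const, one_sided=True):
    """Penalty on the input-gradient norm of the discriminator.

    Assumes x and y have the same shape (batch dimension first).
    """
    # One interpolation coefficient per sample, broadcast over the other dims.
    t = torch.rand(x.shape[0], *([1] * (x.dim() - 1)), device=x.device)
    z = (t * x + (1.0 - t) * y).detach().requires_grad_(True)
    out = discriminator(z)
    # Summing the outputs yields per-sample gradients with respect to z.
    (grad,) = torch.autograd.grad(out.sum(), z, create_graph=True)
    excess = grad.flatten(start_dim=1).norm(dim=1) / lip_const - 1.0
    if one_sided:
        # Only gradient norms above lip_const are penalized.
        excess = torch.clamp(excess, min=0.0)
    return penalty_weight * (excess ** 2).mean()


# A linear discriminator on 4 inputs whose gradient norm is 6 everywhere.
linear_disc = lambda z: 3.0 * z.sum(dim=1, keepdim=True)
x = torch.randn(8, 4)
y = torch.randn(8, 4)
too_steep = gradient_penalty(linear_disc, x, y, 1.0, lip_const=2.0)
within_bound = gradient_penalty(linear_disc, x, y, 1.0, lip_const=12.0)
two_sided = gradient_penalty(linear_disc, x, y, 1.0, lip_const=12.0, one_sided=False)
```

In this sketch the one-sided variant vanishes whenever the gradient norm stays below lip_const, while the two-sided variant also penalizes gradients that are too small.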
- class Divergences_torch.IPM(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Divergence
IPM class for evaluating Integral Probability Metrics (IPM).
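For orientation, an IPM over a function class Gamma is commonly written as the supremum over g in Gamma of E_P[g] - E_Q[g]. The sketch below evaluates that objective for one fixed test function on plain Python lists. The name ipm_objective and the sample values are illustrative and not part of Divergences_torch, and the formula the class actually evaluates may differ in detail.

```python
def ipm_objective(g, xs, ys):
    """Sample estimate of E_P[g] - E_Q[g] for a fixed test function g."""
    mean_p = sum(g(x) for x in xs) / len(xs)
    mean_q = sum(g(y) for y in ys) / len(ys)
    return mean_p - mean_q


xs = [1.0, 2.0, 3.0]  # stand-in samples from P
ys = [0.0, 1.0, 2.0]  # stand-in samples from Q (P shifted by -1)
value = ipm_objective(lambda z: z, xs, ys)  # g(z) = z is 1-Lipschitz
```

In training, the discriminator plays the role of g and the objective is maximized over its parameters. The constraint defining Gamma (for example a Lipschitz bound) is what a discriminator penalty would approximately enforce.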
- class Divergences_torch.Jensen_Shannon_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: f_Divergence
Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.
- class Divergences_torch.KLD_DV(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Divergence
KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.
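The Donsker-Varadhan formula is usually stated as KL(P||Q) = sup over g of E_P[g] - log E_Q[exp(g)]. As a sanity check, the sketch below evaluates this objective on a three-point space, using explicit probability vectors in place of the sample averages a trained discriminator would use. The helper name dv_objective and the numbers are illustrative assumptions, not part of Divergences_torch.

```python
import math


def dv_objective(g, p, q):
    """E_P[g] - log E_Q[exp(g)] for probability vectors p, q and values g."""
    mean_p = sum(pi * gi for pi, gi in zip(p, g))
    log_mean_exp_q = math.log(sum(qi * math.exp(gi) for qi, gi in zip(q, g)))
    return mean_p - log_mean_exp_q


p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
kl_exact = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# The supremum is attained at g = log(dP/dQ), up to an additive constant.
g_opt = [math.log(pi / qi) for pi, qi in zip(p, q)]
at_optimum = dv_objective(g_opt, p, q)

# Any other g gives a lower bound on the divergence.
elsewhere = dv_objective([0.3, -0.2, 0.1], p, q)
```

Because every g yields a lower bound, maximizing the objective over discriminator parameters gives an estimate that approaches KL(P||Q) from below, up to sampling error.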
- class Divergences_torch.KLD_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: f_Divergence
Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.
- class Divergences_torch.Pearson_chi_squared_HCR(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Divergence
Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Pearson chi-squared divergence loss.
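The Hammersley-Chapman-Robbins bound is commonly written as chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g], with equality when g is an affine function of dP/dQ. The sketch below checks this on a three-point space with explicit probability vectors. It illustrates the bound under that assumption and is not the library's implementation; hcr_objective is a hypothetical helper name.

```python
def hcr_objective(g, p, q):
    """(E_P[g] - E_Q[g])^2 / Var_Q[g] for probability vectors p, q."""
    mean_p = sum(pi * gi for pi, gi in zip(p, g))
    mean_q = sum(qi * gi for qi, gi in zip(q, g))
    var_q = sum(qi * (gi - mean_q) ** 2 for qi, gi in zip(q, g))
    return (mean_p - mean_q) ** 2 / var_q


p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
chi2_exact = sum((pi - qi) ** 2 / qi for pi, qi in zip(p, q))

# Equality holds at g = dP/dQ.
g_opt = [pi / qi for pi, qi in zip(p, q)]
at_optimum = hcr_objective(g_opt, p, q)

# A different test function gives a strictly smaller value here.
elsewhere = hcr_objective([1.0, 0.0, 0.0], p, q)
```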
- class Divergences_torch.Pearson_chi_squared_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: f_Divergence
Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.
- class Divergences_torch.Renyi_Divergence(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Divergence
Renyi divergence class. R_alpha(P||Q), x~P, y~Q.
- class Divergences_torch.Renyi_Divergence_CC(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]
Bases: Renyi_Divergence
Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for Renyi divergence based on the convex-conjugate formula.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Renyi divergence loss.
- class Divergences_torch.Renyi_Divergence_CC_rescaled(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]
Bases: Renyi_Divergence_CC
Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for the rescaled Renyi divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Rescaled Renyi divergence loss.
- class Divergences_torch.Renyi_Divergence_DV(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Renyi_Divergence
Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.
- class Divergences_torch.Renyi_Divergence_WCR(discriminator, disc_optimizer, alpha, epochs, batch_size, fl_act_func, discriminator_penalty=None)[source]
Bases: Renyi_Divergence_CC
Renyi divergence class for the limit as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for worst-case regret divergence as alpha approaches infinity.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
Worst-case regret divergence loss.
- class Divergences_torch.alpha_Divergence_LT(discriminator, disc_optimizer, alpha, epochs, batch_size, discriminator_penalty=None)[source]
Bases: f_Divergence
Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.
- class Divergences_torch.f_Divergence(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: Divergence
Parent class for f-divergences D_f(P||Q). Subclasses must implement the Legendre transform f_star.
- eval_var_formula(x, y, labels=None)[source]
Evaluates the variational formula for f-divergence.
- Parameters:
x – Samples from distribution P.
y – Samples from distribution Q.
labels – Optional labels for the input data.
- Returns:
f-divergence loss.
- eval_var_formula_gen(x, labels=None)[source]
Evaluates the variational formula for f-divergence applied to the generator.
- Parameters:
x – Generated samples.
labels – Optional labels for the input data.
- Returns:
Generator loss based on f-divergence.
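The Legendre-transform formula for an f-divergence is usually written as D_f(P||Q) = sup over g of E_P[g] - E_Q[f_star(g)], where f_star is the convex conjugate of f. The sketch below evaluates that objective with two standard conjugates on a three-point space. The choices f(t) = t log t for KL and f(t) = (t - 1)^2 for Pearson chi-squared are textbook conventions assumed here; the subclasses in Divergences_torch may use a different normalization, and the helper names are illustrative.

```python
import math


def f_divergence_objective(g, p, q, f_star):
    """E_P[g] - E_Q[f_star(g)] for probability vectors p, q and values g."""
    mean_p = sum(pi * gi for pi, gi in zip(p, g))
    mean_q_conj = sum(qi * f_star(gi) for qi, gi in zip(q, g))
    return mean_p - mean_q_conj


def kl_f_star(y):
    # Conjugate of f(t) = t * log(t).
    return math.exp(y - 1.0)


def chi2_f_star(y):
    # Conjugate of f(t) = (t - 1)**2.
    return y * y / 4.0 + y


p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
ratio = [pi / qi for pi, qi in zip(p, q)]

# The supremum is attained at g = f'(dP/dQ).
kl_value = f_divergence_objective([1.0 + math.log(r) for r in ratio], p, q, kl_f_star)
chi2_value = f_divergence_objective([2.0 * (r - 1.0) for r in ratio], p, q, chi2_f_star)
```

Under these conventions a subclass only needs to supply f_star; the variational formula itself is shared, which matches the parent-class design described above.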
- class Divergences_torch.squared_Hellinger_LT(discriminator, disc_optimizer, epochs, batch_size, discriminator_penalty=None)[source]
Bases: f_Divergence
Squared Hellinger distance class based on the Legendre transform. H^2(P||Q), x~P, y~Q.