Divergences_tf module

class Divergences_tf.Discriminator_Penalty(penalty_weight)[source]

Bases: object

Base class for discriminator penalties. A penalty term is added to the divergence objective functional during training, which allows discriminator constraints to be enforced approximately.

evaluate(discriminator, x, y, labels=None)[source]

Evaluates the penalty term. Should be overridden by subclasses.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None. (Subclasses should implement penalty evaluation.)

get_penalty_weight()[source]

Returns the weight of the penalty term.

set_penalty_weight(weight)[source]

Sets the weight of the penalty term.
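
The override pattern described above can be sketched without the library. Both classes below are stand-ins written for illustration: PenaltyBase only mirrors the documented interface, and OutputMagnitudePenalty is a hypothetical penalty that is not part of Divergences_tf. The discriminator is assumed to be a plain callable on scalars.

```python
class PenaltyBase:
    """Stand-in that mirrors the documented interface (not the library class)."""

    def __init__(self, penalty_weight):
        self.penalty_weight = penalty_weight

    def get_penalty_weight(self):
        return self.penalty_weight

    def set_penalty_weight(self, weight):
        self.penalty_weight = weight

    def evaluate(self, discriminator, x, y, labels=None):
        return None  # the base class defines no penalty


class OutputMagnitudePenalty(PenaltyBase):
    """Hypothetical penalty: weight * mean(D(z)^2) over both sample sets."""

    def evaluate(self, discriminator, x, y, labels=None):
        outputs = [discriminator(z) for z in list(x) + list(y)]
        return self.penalty_weight * sum(o * o for o in outputs) / len(outputs)


penalty = OutputMagnitudePenalty(penalty_weight=0.5)
value = penalty.evaluate(lambda z: 2.0 * z, x=[1.0, -1.0], y=[0.0, 2.0])
```

A subclass only needs to override evaluate; the weight accessors are inherited unchanged.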

class Divergences_tf.Divergence(*args, **kwargs)[source]

Bases: Model

Divergence D(P||Q) between random variables x~P, y~Q. Parent class that defines the parameters and methods shared by all divergence classes.

discriminate(x, labels=None)[source]

Discriminates between samples from distributions P and Q.

Parameters:
  • x – Input data to discriminate.

  • labels – Optional labels for the input data.

Returns:

Discriminator output.

discriminator_loss(x, y, labels=None)[source]

Computes the discriminator loss.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Discriminator loss.

estimate(x, y, labels=None)[source]

Estimates the divergence measure.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Estimated divergence loss.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for the divergence measure. Should be implemented by subclasses.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

None (to be overridden by subclasses).

generator_loss(x, labels=None)[source]

Computes the generator loss.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss.

get_batch_size()[source]

Returns the batch size.

get_discriminator()[source]

Returns the discriminator model.

get_learning_rate()[source]

Returns the learning rate.

get_no_epochs()[source]

Returns the number of training epochs.

set_batch_size(BATCH_SIZE)[source]

Sets the batch size.

set_discriminator(discriminator)[source]

Sets a new discriminator model.

set_learning_rate(lr)[source]

Sets the learning rate.

set_no_epochs(epochs)[source]

Sets the number of training epochs.

train(data_P, data_Q, labels=None, save_estimates=True)[source]

Trains the model for a number of epochs.

Parameters:
  • data_P – Data samples from distribution P.

  • data_Q – Data samples from distribution Q.

  • labels – Optional labels for the input data.

  • save_estimates – Whether to save divergence estimates.

Returns:

A list of divergence estimates for each epoch.

train_step(x, y, labels=None)[source]

Performs a training step for the discriminator.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Loss value for the current step.

class Divergences_tf.Gradient_Penalty_1Sided(penalty_weight, Lip_const)[source]

Bases: Discriminator_Penalty

One-sided gradient penalty class to enforce the Lipschitz constant <= Lip_const.

evaluate(discriminator, x, y, labels=None)[source]

Computes the one-sided gradient penalty to enforce the Lipschitz constant <= Lip_const.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

One-sided gradient penalty value.
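
This page does not state the exact formula, so the following is a minimal sketch of one common one-sided form: only gradient norms above Lip_const are penalized. The function name, the squared hinge, and the idea of passing precomputed gradient norms are assumptions made for illustration. In practice the norms would come from differentiating the discriminator at chosen evaluation points.

```python
def one_sided_gradient_penalty(grad_norms, lip_const, penalty_weight):
    """penalty_weight * mean(max(0, ||grad D|| / L - 1)^2) (assumed form)."""
    excess = [max(0.0, n / lip_const - 1.0) for n in grad_norms]
    return penalty_weight * sum(e * e for e in excess) / len(excess)


# Gradient norms at three hypothetical evaluation points.
penalty = one_sided_gradient_penalty([0.5, 1.0, 2.0], lip_const=1.0, penalty_weight=10.0)
```

Only the third point exceeds the target constant, so it alone contributes to the penalty.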

get_Lip_constant()[source]

Returns the target Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.

class Divergences_tf.Gradient_Penalty_2Sided(penalty_weight, Lip_const)[source]

Bases: Discriminator_Penalty

Two-sided gradient penalty class to enforce the Lipschitz constant = Lip_const.

evaluate(discriminator, x, y, labels=None)[source]

Computes the two-sided gradient penalty to enforce the Lipschitz constant = Lip_const.

Parameters:
  • discriminator – Discriminator model.

  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Two-sided gradient penalty value.
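
For orientation, a two-sided penalty is commonly written as below. This is an assumed form rather than one taken from the library source: \lambda stands for penalty_weight, L for Lip_const, and \hat{z} for the points where the gradient is evaluated (in the WGAN-GP literature these are often random interpolates between x and y).

```latex
\text{penalty}
  = \lambda \, \mathbb{E}_{\hat{z}}\!\left[
      \left( \frac{\lVert \nabla_{\hat{z}} D(\hat{z}) \rVert_2}{L} - 1 \right)^{2}
    \right]
```

Unlike a one-sided penalty, this form also penalizes gradient norms below L, which pushes the norm toward L rather than merely bounding it.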

get_Lip_constant()[source]

Returns the target Lipschitz constant.

set_Lip_constant(L)[source]

Sets the target Lipschitz constant.

class Divergences_tf.IPM(*args, **kwargs)[source]

Bases: Divergence

Class for evaluating integral probability metrics (IPMs).

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for IPM.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

IPM loss.
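
An IPM is the supremum, over a class of discriminators D, of E_P[D(x)] - E_Q[D(y)]. A minimal sketch of the sample-based objective for one fixed discriminator follows; the function name and the use of precomputed discriminator outputs are illustrative assumptions, not the library's implementation.

```python
from statistics import fmean


def ipm_objective(d_x, d_y):
    """E_P[D(x)] - E_Q[D(y)] from discriminator outputs on the two batches."""
    return fmean(d_x) - fmean(d_y)


d_x = [1.0, 3.0]  # hypothetical discriminator outputs on x ~ P
d_y = [0.0, 1.0]  # hypothetical discriminator outputs on y ~ Q
value = ipm_objective(d_x, d_y)

# Shifting D by a constant leaves the objective unchanged.
shifted = ipm_objective([v + 5.0 for v in d_x], [v + 5.0 for v in d_y])
```

The objective is meaningful only under a constraint on D (for example a Lipschitz bound), which is why penalties such as the gradient penalties are paired with it.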

eval_var_formula_gen(x, labels=None)[source]

Evaluates the variational formula for IPM applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on IPM.

class Divergences_tf.Jensen_Shannon_LT(*args, **kwargs)[source]

Bases: f_Divergence

Jensen-Shannon divergence class based on the Legendre transform. JS(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y) - (y + 1) * log((y + 1) / 2).

final_layer_activation(y)[source]

Applies the final layer activation for Jensen-Shannon divergence.
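
For the f stated above, the Legendre transform works out to f*(t) = -log(2 - exp(t)), which is finite only for t < log(2). That restriction is why the discriminator output needs a final activation. The sketch below is illustrative only: bounded_activation is a hypothetical choice (the one used in the f-GAN literature) and may differ from what the library applies.

```python
import math

LOG2 = math.log(2.0)


def f_js(u):
    """f(u) = u*log(u) - (u + 1)*log((u + 1)/2), for u > 0."""
    return u * math.log(u) - (u + 1.0) * math.log((u + 1.0) / 2.0)


def f_star_js(t):
    """Legendre transform of f_js; finite only for t < log(2)."""
    if t >= LOG2:
        raise ValueError("f* is finite only for t < log(2)")
    return -math.log(2.0 - math.exp(t))


def bounded_activation(v):
    """Hypothetical activation mapping real v into (-inf, log 2)."""
    return LOG2 - math.log1p(math.exp(-v))


# Fenchel-Young holds with equality at t = f'(u) = log(2u / (u + 1)).
u = 3.0
t = math.log(2.0 * u / (u + 1.0))
gap = f_star_js(t) - (t * u - f_js(u))
```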

class Divergences_tf.KLD_DV(*args, **kwargs)[source]

Bases: Divergence

KL divergence class based on the Donsker-Varadhan variational formula. KL(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for KL divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

KL divergence loss.
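
The Donsker-Varadhan formula states KL(P||Q) = sup_g { E_P[g] - log E_Q[exp(g)] }. A minimal sketch of the sample-based objective for one fixed discriminator g is given below; the function name and the plain-list inputs are assumptions for illustration and not the library's code.

```python
import math
from statistics import fmean


def dv_objective(g_x, g_y):
    """E_P[g] - log E_Q[exp(g)] from discriminator outputs on x~P, y~Q."""
    return fmean(g_x) - math.log(fmean([math.exp(v) for v in g_y]))


# A constant discriminator carries no information, so the objective is zero.
constant_case = dv_objective([2.0, 2.0], [2.0, 2.0])

# Hypothetical outputs that separate the two batches give a positive value.
separated_case = dv_objective([1.0, 1.0], [0.0, 0.0])
```

Training maximizes this quantity over the discriminator parameters; the value at the optimum is the divergence estimate.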

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on KL divergence.

class Divergences_tf.KLD_LT(*args, **kwargs)[source]

Bases: f_Divergence

Kullback-Leibler (KL) divergence class based on the Legendre transform. KL(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = y * log(y).
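
For f(y) = y * log(y), the Legendre transform has the closed form f*(t) = exp(t - 1). The sketch below checks that closed form against a brute-force supremum of t*u - f(u) over a grid; the helper names are illustrative and not part of the library.

```python
import math


def f_kl(u):
    """f(u) = u*log(u), for u > 0."""
    return u * math.log(u)


def f_star_kl(t):
    """Closed-form Legendre transform of f_kl."""
    return math.exp(t - 1.0)


def legendre_on_grid(f, t, grid):
    """Brute-force sup_u {t*u - f(u)} over a finite grid of u > 0."""
    return max(t * u - f(u) for u in grid)


grid = [k / 1000.0 for k in range(1, 20001)]  # u in (0, 20]
approx = legendre_on_grid(f_kl, 1.5, grid)
exact = f_star_kl(1.5)
```

The grid value can only undershoot the true supremum, and it agrees with the closed form up to the grid resolution.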

class Divergences_tf.Pearson_chi_squared_HCR(*args, **kwargs)[source]

Bases: Divergence

Pearson chi-squared divergence class based on the Hammersley-Chapman-Robbins bound. chi^2(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Pearson chi-squared divergence based on the Hammersley-Chapman-Robbins bound.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Pearson chi-squared divergence loss.
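
The Hammersley-Chapman-Robbins bound gives chi^2(P||Q) >= (E_P[g] - E_Q[g])^2 / Var_Q[g] for any test function g, with equality attained at the supremum. A minimal sketch of the sample-based objective follows; the function name and the use of the population variance are assumptions, and the library's estimator may differ.

```python
from statistics import fmean, pvariance


def hcr_objective(g_x, g_y):
    """(E_P[g] - E_Q[g])^2 / Var_Q[g] from discriminator outputs."""
    gap = fmean(g_x) - fmean(g_y)
    return gap * gap / pvariance(g_y)


value = hcr_objective([2.0, 2.0], [0.0, 2.0])

# The objective is invariant to rescaling the discriminator.
rescaled = hcr_objective([6.0, 6.0], [0.0, 6.0])
```

Because of this scale invariance, the objective does not require a bounded final-layer activation.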

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Pearson chi-squared divergence.

class Divergences_tf.Pearson_chi_squared_LT(*args, **kwargs)[source]

Bases: f_Divergence

Pearson chi-squared divergence class based on the Legendre transform. chi^2(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (y - 1)^2.
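
As a worked derivation, assuming f is extended to all real u (restricting to u >= 0 would change the result for t < -2), the supremum is attained where the derivative of t*u - (u - 1)^2 vanishes:

```latex
f^*(t) = \sup_{u \in \mathbb{R}} \left\{ t u - (u - 1)^2 \right\},
\qquad
t - 2(u - 1) = 0 \;\Rightarrow\; u = 1 + \tfrac{t}{2},
\qquad
f^*(t) = t\left(1 + \tfrac{t}{2}\right) - \tfrac{t^2}{4} = \tfrac{t^2}{4} + t .
```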

class Divergences_tf.Renyi_Divergence(*args, **kwargs)[source]

Bases: Divergence

Parent class for Renyi divergences R_alpha(P||Q), x~P, y~Q. Defines the order alpha shared by its subclasses.

get_order()[source]

Returns the order of the Renyi divergence.

set_order(alpha)[source]

Sets the order of the Renyi divergence.

class Divergences_tf.Renyi_Divergence_CC(*args, **kwargs)[source]

Bases: Renyi_Divergence

Renyi divergence class based on the convex-conjugate variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Renyi divergence based on the convex-conjugate formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Renyi divergence using the convex-conjugate formula.

final_layer_activation(y)[source]

Applies the final layer activation to enforce positive values.

class Divergences_tf.Renyi_Divergence_CC_rescaled(*args, **kwargs)[source]

Bases: Renyi_Divergence_CC

Rescaled Renyi divergence class based on the rescaled convex-conjugate variational formula. alpha * R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for the rescaled Renyi divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Rescaled Renyi divergence loss.

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on rescaled Renyi divergence.

final_layer_activation(y)[source]

Applies the final layer activation and scales it by alpha.

class Divergences_tf.Renyi_Divergence_DV(*args, **kwargs)[source]

Bases: Renyi_Divergence

Renyi divergence class based on the Renyi-Donsker-Varadhan variational formula. R_alpha(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for Renyi divergence based on the Renyi-Donsker-Varadhan formula.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

Renyi divergence loss.
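
One published form of the Renyi-Donsker-Varadhan objective is (1/(alpha-1)) * log E_P[exp((alpha-1)*g)] - (1/alpha) * log E_Q[exp(alpha*g)]. Its supremum equals the Renyi divergence normalized by 1/(alpha*(alpha-1)), which is the classical Renyi divergence divided by alpha. Whether the library uses this normalization is an assumption here, and the function names below are illustrative.

```python
import math
from statistics import fmean


def log_mean_exp(values):
    """Numerically stable log(mean(exp(v)))."""
    m = max(values)
    return m + math.log(fmean([math.exp(v - m) for v in values]))


def renyi_dv_objective(g_x, g_y, alpha):
    """Assumed form of the objective, from discriminator outputs on x~P, y~Q."""
    if alpha in (0.0, 1.0):
        raise ValueError("alpha must differ from 0 and 1")
    term_p = log_mean_exp([(alpha - 1.0) * v for v in g_x]) / (alpha - 1.0)
    term_q = log_mean_exp([alpha * v for v in g_y]) / alpha
    return term_p - term_q


# A constant discriminator gives zero for any admissible alpha.
constant_case = renyi_dv_objective([0.7, 0.7], [0.7, 0.7], alpha=2.0)
separated_case = renyi_dv_objective([1.0], [0.0], alpha=2.0)
```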

eval_var_formula_gen(x, labels=None)[source]

Evaluates the generator’s objective based on Renyi divergence.

class Divergences_tf.Renyi_Divergence_WCR(*args, **kwargs)[source]

Bases: Renyi_Divergence_CC

Renyi divergence class in the limit as alpha approaches infinity (worst-case regret divergence). D_infty(P||Q), x~P, y~Q.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for worst-case regret divergence as alpha approaches infinity.

Parameters:
  • x (tf.Tensor) – Samples from distribution P.

  • y (tf.Tensor) – Samples from distribution Q.

  • labels (tf.Tensor, optional) – Optional labels for the input data.

Returns:

Worst-case regret divergence loss.

Return type:

tf.Tensor

eval_var_formula_gen(X, labels=None)[source]

Evaluates the generator’s objective based on worst-case regret divergence.

Parameters:
  • X (tf.Tensor) – Samples from distribution Q.

  • labels (tf.Tensor, optional) – Optional labels for the input data.

Returns:

Generator’s loss based on worst-case regret divergence.

Return type:

tf.Tensor

class Divergences_tf.alpha_Divergence_LT(*args, **kwargs)[source]

Bases: f_Divergence

Alpha-divergence class based on the Legendre transform. D_{f_alpha}(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f_alpha based on the alpha value.

get_order()[source]

Returns the order of the alpha-divergence.

set_order(alpha)[source]

Sets the order of the alpha-divergence.

class Divergences_tf.f_Divergence(*args, **kwargs)[source]

Bases: Divergence

f-divergence class, parent class for f-divergence-based measures D_f(P||Q). Subclasses must implement the Legendre transform f_star.

eval_var_formula(x, y, labels=None)[source]

Evaluates the variational formula for f-divergence.

Parameters:
  • x – Samples from distribution P.

  • y – Samples from distribution Q.

  • labels – Optional labels for the input data.

Returns:

f-divergence loss.
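
The Legendre-transform formula for an f-divergence is D_f(P||Q) = sup_g { E_P[g] - E_Q[f*(g)] }. The sketch below evaluates the objective for a fixed g and checks it on a two-point example where the optimal g = f'(dP/dQ) is known. The function name and the toy data are assumptions for illustration, with the KL choice f*(t) = exp(t - 1) plugged in.

```python
import math
from statistics import fmean


def f_divergence_objective(g_x, g_y, f_star):
    """E_P[g] - E_Q[f*(g)] from discriminator outputs on x~P, y~Q."""
    return fmean(g_x) - fmean([f_star(v) for v in g_y])


def kl_f_star(t):
    return math.exp(t - 1.0)


# Empirical P puts mass 3/4 on "a"; empirical Q puts mass 1/4 on "a".
x = ["a", "a", "a", "b"]
y = ["a", "b", "b", "b"]

# Optimal discriminator for KL: g = log(dP/dQ) + 1.
g = {"a": math.log(3.0) + 1.0, "b": math.log(1.0 / 3.0) + 1.0}

estimate = f_divergence_objective([g[s] for s in x], [g[s] for s in y], kl_f_star)
exact_kl = 0.75 * math.log(3.0) + 0.25 * math.log(1.0 / 3.0)
```

With the optimal g the objective reproduces KL(P||Q) exactly; any other g gives a lower value.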

eval_var_formula_gen(x, labels=None)[source]

Evaluates the variational formula for f-divergence applied to the generator.

Parameters:
  • x – Generated samples.

  • labels – Optional labels for the input data.

Returns:

Generator loss based on f-divergence.

f_star(y)[source]

Placeholder for the Legendre transform of f. Should be implemented by subclasses.

final_layer_activation(y)[source]

Applies the final layer activation.

class Divergences_tf.squared_Hellinger_LT(*args, **kwargs)[source]

Bases: f_Divergence

Squared Hellinger distance class based on the Legendre transform. H^2(P||Q), x~P, y~Q.

f_star(y)[source]

Legendre transform of f(y) = (sqrt(y) - 1)^2.

final_layer_activation(y)[source]

Applies the final layer activation for squared Hellinger distance.
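
For f(y) = (sqrt(y) - 1)^2, the Legendre transform works out to f*(t) = t / (1 - t), finite only for t < 1, so the discriminator output must be kept below 1. The sketch below is illustrative: upper_bounded_activation is a hypothetical choice taken from the f-GAN literature and is not confirmed to match the library.

```python
import math


def f_hellinger(u):
    """f(u) = (sqrt(u) - 1)^2, for u >= 0."""
    return (math.sqrt(u) - 1.0) ** 2


def f_star_hellinger(t):
    """Legendre transform of f_hellinger; finite only for t < 1."""
    if t >= 1.0:
        raise ValueError("f* is finite only for t < 1")
    return t / (1.0 - t)


def upper_bounded_activation(v):
    """Hypothetical activation mapping real v into (-inf, 1)."""
    return 1.0 - math.exp(-v)


# Fenchel-Young holds with equality at t = f'(u) = 1 - 1/sqrt(u).
u = 4.0
t = 1.0 - 1.0 / math.sqrt(u)
gap = f_star_hellinger(t) - (t * u - f_hellinger(u))
```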