GAN_torch module

class GAN_torch.GAN(divergence, generator, gen_optimizer, noise_source, epochs, disc_steps_per_gen_step, batch_size=None, reverse_order=False, include_penalty_in_gen_loss=False)

Bases: object

Class for training a GAN using one of the provided divergences. Let P be the distribution to be learned, Z the noise source and g_theta the generator with parameters theta. If reverse_order=False, the GAN solves min_theta D(P||g_theta(Z)); if reverse_order=True, it solves min_theta D(g_theta(Z)||P).
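Because most divergences are not symmetric, reverse_order generally changes the objective being minimized. A minimal sketch in plain Python, using the KL divergence between two discrete distributions purely as an illustration (an assumption; it is not necessarily one of the module's divergences). generator_objective is a hypothetical helper, not part of GAN_torch:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generator_objective(p, q_theta, divergence, reverse_order=False):
    """Hypothetical helper: the divergence the generator minimizes.

    reverse_order=False -> D(P || g_theta(Z)); reverse_order=True -> D(g_theta(Z) || P).
    """
    return divergence(q_theta, p) if reverse_order else divergence(p, q_theta)

p = [0.5, 0.5]  # target distribution P
q = [0.9, 0.1]  # distribution of g_theta(Z)
forward = generator_objective(p, q, kl_divergence)
reverse = generator_objective(p, q, kl_divergence, reverse_order=True)
print(f"{forward:.4f} {reverse:.4f}")  # 0.5108 0.3681
```

The two orders give different values for the same pair of distributions, so the generator is pushed toward different optima depending on reverse_order.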

disc_train_step(x, z)

Updates the discriminator’s parameters using the batch (x, z).
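A discriminator update typically performs gradient ascent on a variational estimate of the divergence. As a stdlib-only sketch, the code below assumes the Donsker-Varadhan formula KL(P||Q) = sup_f E_P[f] - log E_Q[exp f], a one-parameter linear discriminator f(t) = a*t and Gaussian toy data. The names dv_bound and toy_disc_step are hypothetical; the module itself presumably trains a neural-network discriminator with PyTorch:

```python
import math
import random

def dv_bound(a, x, y):
    """Empirical E_P[f] - log E_Q[exp f] for the linear discriminator f(t) = a*t."""
    return sum(a * t for t in x) / len(x) - math.log(sum(math.exp(a * t) for t in y) / len(y))

def toy_disc_step(a, x, y, lr=0.1):
    """One gradient-ascent step on the discriminator parameter a (generator samples y fixed)."""
    w = [math.exp(a * t) for t in y]
    # d/da of the bound: mean of x minus the exp(a*y)-weighted mean of y
    grad = sum(x) / len(x) - sum(wi * t for wi, t in zip(w, y)) / sum(w)
    return a + lr * grad

random.seed(0)
x = [random.gauss(1.0, 1.0) for _ in range(2000)]  # batch from P = N(1, 1)
y = [random.gauss(0.0, 1.0) for _ in range(2000)]  # batch from g_theta(Z) = N(0, 1)
a = 0.0
for _ in range(200):
    a = toy_disc_step(a, x, y)
print(round(a, 2), round(dv_bound(a, x, y), 2))  # roughly 1 and 0.5
```

For P = N(1, 1) and Q = N(0, 1) the optimal linear discriminator has a = 1 and the bound then equals the true KL divergence of 0.5, so the printed values should land close to those.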

estimate_loss(x, z)

Estimates the loss on the batch (x, z).
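When a loss estimate contains a log-mean-exp term, as variational formulas such as Donsker-Varadhan do, computing it naively can overflow for large discriminator outputs. A sketch of a numerically stable batch estimate, again assuming the Donsker-Varadhan form as a stand-in for whichever divergence is selected; log_mean_exp and dv_loss_estimate are hypothetical names:

```python
import math

def log_mean_exp(values):
    """log((1/n) * sum(exp(v))) computed stably by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def dv_loss_estimate(f_on_x, f_on_gz):
    """Monte Carlo estimate of E_P[f] - log E_Q[exp f] from discriminator outputs.

    f_on_x:  discriminator values on a batch x ~ P
    f_on_gz: discriminator values on generator samples g_theta(z)
    """
    return sum(f_on_x) / len(f_on_x) - log_mean_exp(f_on_gz)

# Outputs this large would overflow a naive math.exp (math.exp(1000) raises OverflowError).
f_x = [1000.0, 1001.0]
f_gz = [999.0, 1000.0]
print(dv_loss_estimate(f_x, f_gz))
```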

gen_train_step(x, z)

Updates the generator’s parameters using the batch (x, z).
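A generator update differentiates a loss estimate with respect to theta while the discriminator is held fixed, with gradients flowing through the samples g_theta(z). A minimal sketch under toy assumptions (a Donsker-Varadhan loss, a linear discriminator f(t) = a*t, a shift generator g_theta(z) = theta + z and the forward order D(P||g_theta(Z))); toy_gen_step is hypothetical:

```python
import math

def toy_gen_step(theta, a, x, z, lr=0.1):
    """One gradient-descent step on theta with the discriminator f(t) = a*t held fixed.

    Loss: L(theta) = mean(f(x)) - log mean(exp(f(theta + z))), with g_theta(z) = theta + z.
    x is accepted to mirror gen_train_step(x, z) but the data term does not depend on theta.
    """
    y = [theta + s for s in z]                # generator samples g_theta(z)
    m = max(a * t for t in y)
    w = [math.exp(a * t - m) for t in y]      # unnormalized softmax weights over the batch
    # Chain rule: dL/dtheta = -sum_i w_i * f'(y_i) * dg/dtheta / sum_i w_i,
    # with f'(y) = a and dg/dtheta = 1.
    grad = -sum(wi * a * 1.0 for wi in w) / sum(w)
    return theta - lr * grad

new_theta = toy_gen_step(theta=0.0, a=1.0, x=[1.2, 0.8], z=[-0.5, 0.3, 0.1])
print(new_theta)  # 0.1: the generator's location moves toward the data
```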

generate_samples(N_samples, device='cpu')

train(data_P, save_frequency=None, num_gen_samples_to_save=None, save_loss_estimates=False)

Trains the GAN on data_P, the samples from the distribution P to be learned.
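Judging from the constructor arguments, training presumably alternates disc_steps_per_gen_step discriminator updates with one generator update for a number of epochs, saving results periodically when save_frequency is set. The outline below is an assumption-based sketch of that schedule, not the module's actual train method; toy_train, the stub update functions and the treatment of batch_size=None as full-batch are all hypothetical:

```python
import random

def toy_train(data_P, epochs, disc_steps_per_gen_step, disc_step, gen_step,
              sample_noise, batch_size=None, save_frequency=None):
    """Hypothetical outline of an alternating GAN training loop.

    Each epoch performs `disc_steps_per_gen_step` discriminator updates followed by
    one generator update. Returns the epochs at which a checkpoint would be saved.
    """
    def minibatch():
        # Assumption: batch_size=None means "use the full dataset".
        x = data_P if batch_size is None else random.sample(data_P, batch_size)
        return x, sample_noise(len(x))

    saved_epochs = []
    for epoch in range(1, epochs + 1):
        for _ in range(disc_steps_per_gen_step):
            disc_step(*minibatch())
        gen_step(*minibatch())
        if save_frequency and epoch % save_frequency == 0:
            saved_epochs.append(epoch)  # generated samples / loss estimates would be stored here
    return saved_epochs

# Usage with counting stubs in place of real update functions.
counts = {"disc": 0, "gen": 0}

def count_disc(x, z):
    counts["disc"] += 1

def count_gen(x, z):
    counts["gen"] += 1

def gaussian_noise(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]

data = [random.gauss(1.0, 1.0) for _ in range(100)]
saved = toy_train(data, epochs=10, disc_steps_per_gen_step=5, disc_step=count_disc,
                  gen_step=count_gen, sample_noise=gaussian_noise, batch_size=16,
                  save_frequency=4)
print(counts, saved)  # {'disc': 50, 'gen': 10} [4, 8]
```

Keeping the schedule separate from the update rules makes the bookkeeping easy to check with stubs: 10 epochs with 5 discriminator steps each give 50 discriminator and 10 generator updates.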