GAN_CIFAR10_tf module
- class GAN_CIFAR10_tf.Discriminator(*args, **kwargs)
Bases: Model
Discriminator model that scores input images as real (drawn from the training data) or fake (produced by the generator).
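A minimal sketch of how a discriminator of this kind can be assembled in `tf.keras`. It assumes 32x32x3 channels-last CIFAR-10 inputs and a single unnormalized logit per image. The strided `Conv2D` stack is a simplified stand-in for the module's `DiscriminatorBlock` layers, and the channel width of 64 is an illustrative assumption, not the module's actual configuration.

```python
import tensorflow as tf


class Discriminator(tf.keras.Model):
    """Hypothetical sketch: maps 32x32x3 images to one real/fake logit each."""

    def __init__(self, channels=64):  # channel width is an assumption
        super().__init__()
        # Simplified stand-in for the module's DiscriminatorBlock stack:
        # two strided convolutions, each halving the spatial size.
        self.features = [
            tf.keras.layers.Conv2D(channels, 3, strides=2, padding="same"),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Conv2D(channels * 2, 3, strides=2, padding="same"),
            tf.keras.layers.LeakyReLU(),
        ]
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.head = tf.keras.layers.Dense(1)  # unnormalized score (logit)

    def call(self, inputs):
        x = inputs
        for layer in self.features:
            x = layer(x)
        return self.head(self.pool(x))


disc = Discriminator()
logits = disc(tf.random.normal([8, 32, 32, 3]))
print(logits.shape)  # (8, 1)
```

Returning a raw logit rather than a sigmoid probability pairs naturally with a loss such as `tf.keras.losses.BinaryCrossentropy(from_logits=True)` or a hinge loss.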
- class GAN_CIFAR10_tf.DiscriminatorBlock(*args, **kwargs)
Bases: Layer
ResNet-style block for the discriminator model, with optional downsampling.
- Parameters:
in_chans (int) – Number of input channels.
out_chans (int) – Number of output channels.
downsample (bool) – Whether to apply 2x downsampling.
first (bool) – Whether this is the first block in the discriminator.
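One plausible reading of these parameters, sketched as a pre-activation residual block in `tf.keras`. The specific layer choices are assumptions, not the module's code: two 3x3 convolutions on the main path, 2x average pooling for `downsample`, a learned 1x1 shortcut when the shape changes, and `first` skipping the leading ReLU because the first block sees raw pixels. Spectral normalization, common in GAN discriminators, is omitted for brevity.

```python
import tensorflow as tf


class DiscriminatorBlock(tf.keras.layers.Layer):
    """Hypothetical pre-activation residual block; not the module's exact layers."""

    def __init__(self, in_chans, out_chans, downsample=False, first=False, **kwargs):
        super().__init__(**kwargs)
        self.downsample = downsample
        self.first = first
        self.conv1 = tf.keras.layers.Conv2D(out_chans, 3, padding="same")
        self.conv2 = tf.keras.layers.Conv2D(out_chans, 3, padding="same")
        self.pool = tf.keras.layers.AveragePooling2D(2)  # 2x downsampling
        # A learned 1x1 projection is only needed when the shortcut's shape changes.
        self.learn_shortcut = in_chans != out_chans or downsample
        if self.learn_shortcut:
            self.shortcut_conv = tf.keras.layers.Conv2D(out_chans, 1)

    def call(self, inputs):
        # Main path; the first block skips the leading activation.
        x = inputs if self.first else tf.nn.relu(inputs)
        x = self.conv1(x)
        x = self.conv2(tf.nn.relu(x))
        if self.downsample:
            x = self.pool(x)

        # Shortcut path, projected and pooled to match the main path.
        skip = inputs
        if self.learn_shortcut:
            skip = self.shortcut_conv(skip)
            if self.downsample:
                skip = self.pool(skip)
        return x + skip


block = DiscriminatorBlock(3, 64, downsample=True, first=True)
y = block(tf.random.normal([4, 32, 32, 3]))
print(y.shape)  # (4, 16, 16, 64)
```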
- class GAN_CIFAR10_tf.Generator(*args, **kwargs)
Bases: Model
Generator model that consists of a dense layer followed by multiple generator blocks.
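A minimal sketch of the "dense layer followed by generator blocks" layout, assuming a 128-dimensional latent vector, a 4x4 starting feature map, and three 2x upsampling stages to reach 32x32 CIFAR-10 resolution. Each stage here is a simplified `Sequential` stand-in (upsample, conv, batch norm, ReLU) for the module's `GeneratorBlock`; the widths and the final `tanh` output in [-1, 1] are illustrative assumptions.

```python
import tensorflow as tf


class Generator(tf.keras.Model):
    """Hypothetical sketch: latent vector -> 32x32x3 image in [-1, 1]."""

    def __init__(self, base_chans=256):  # width is an assumption
        super().__init__()
        self.base_chans = base_chans
        # Dense projection that is reshaped into a 4x4 feature map.
        self.dense = tf.keras.layers.Dense(4 * 4 * base_chans)
        # Simplified stand-ins for GeneratorBlock, each doubling height and width.
        self.blocks = [
            tf.keras.Sequential([
                tf.keras.layers.UpSampling2D(2),
                tf.keras.layers.Conv2D(base_chans, 3, padding="same"),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.ReLU(),
            ])
            for _ in range(3)  # 4 -> 8 -> 16 -> 32
        ]
        # Map features to 3 RGB channels squashed into [-1, 1].
        self.to_rgb = tf.keras.layers.Conv2D(3, 3, padding="same", activation="tanh")

    def call(self, z, training=False):
        x = self.dense(z)
        x = tf.reshape(x, [-1, 4, 4, self.base_chans])
        for block in self.blocks:
            x = block(x, training=training)
        return self.to_rgb(x)


gen = Generator()
images = gen(tf.random.normal([2, 128]))
print(images.shape)  # (2, 32, 32, 3)
```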
- class GAN_CIFAR10_tf.GeneratorBlock(*args, **kwargs)
Bases: Layer
ResNet-style block for the generator model, with optional upsampling.
- Parameters:
in_chans (int) – Number of input channels.
out_chans (int) – Number of output channels.
upsample (bool) – Whether to apply 2x upsampling.
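A sketch of a block matching this signature, with the `call` and `compute_output_shape` methods the section documents. The internals are assumptions, not the module's code: a pre-activation path (batch norm, ReLU, optional nearest-neighbour upsampling, two 3x3 convolutions) plus a shortcut that is upsampled and, when the channel count changes, projected with a 1x1 convolution. The `training` argument to `call` is also an addition for batch normalization.

```python
import tensorflow as tf


class GeneratorBlock(tf.keras.layers.Layer):
    """Hypothetical pre-activation residual block with optional 2x upsampling."""

    def __init__(self, in_chans, out_chans, upsample=False, **kwargs):
        super().__init__(**kwargs)
        self.out_chans = out_chans
        self.upsample = upsample
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.conv1 = tf.keras.layers.Conv2D(out_chans, 3, padding="same")
        self.conv2 = tf.keras.layers.Conv2D(out_chans, 3, padding="same")
        self.up = tf.keras.layers.UpSampling2D(2)  # nearest-neighbour by default
        # A 1x1 projection is only needed when the channel count changes.
        self.learn_shortcut = in_chans != out_chans
        if self.learn_shortcut:
            self.shortcut_conv = tf.keras.layers.Conv2D(out_chans, 1)

    def call(self, inputs, training=False):
        # Main path: BN -> ReLU -> (upsample) -> conv -> BN -> ReLU -> conv.
        x = tf.nn.relu(self.bn1(inputs, training=training))
        if self.upsample:
            x = self.up(x)
        x = self.conv1(x)
        x = self.conv2(tf.nn.relu(self.bn2(x, training=training)))

        # Shortcut path: match the main path's spatial size and channels.
        skip = self.up(inputs) if self.upsample else inputs
        if self.learn_shortcut:
            skip = self.shortcut_conv(skip)
        return x + skip

    def compute_output_shape(self, input_shape):
        # Channels-last (batch, height, width, channels); unknown dims stay None.
        batch, height, width, _ = input_shape
        factor = 2 if self.upsample else 1
        height = None if height is None else height * factor
        width = None if width is None else width * factor
        return (batch, height, width, self.out_chans)


block = GeneratorBlock(256, 128, upsample=True)
y = block(tf.random.normal([2, 4, 4, 256]))
print(y.shape)  # (2, 8, 8, 128)
print(block.compute_output_shape((2, 4, 4, 256)))  # (2, 8, 8, 128)
```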
- call(inputs)
Forward pass of the generator block.
- Parameters:
inputs (tf.Tensor) – Input tensor.
- Returns:
Output tensor after passing through the block.
- Return type:
tf.Tensor
- compute_output_shape(input_shape)
Compute the output shape of the block.
- Parameters:
input_shape – The shape of the input tensor.
- Returns:
Tuple of the output shape.
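The shape arithmetic behind this method can be shown without TensorFlow. The helper below is hypothetical and assumes channels-last `(batch, height, width, channels)` shapes: 2x upsampling doubles height and width, the channel dimension becomes `out_chans`, and unknown dimensions (`None`) are passed through unchanged.

```python
def generator_block_output_shape(input_shape, out_chans, upsample):
    """Hypothetical helper mirroring a GeneratorBlock's shape logic.

    Assumes channels-last (batch, height, width, channels) tuples, with
    None marking an unknown dimension.
    """
    batch, height, width, _ = input_shape
    factor = 2 if upsample else 1
    new_height = None if height is None else height * factor
    new_width = None if width is None else width * factor
    return (batch, new_height, new_width, out_chans)


print(generator_block_output_shape((None, 8, 8, 256), 128, upsample=True))
# (None, 16, 16, 128)
```

Propagating `None` rather than raising lets Keras infer shapes for symbolic inputs whose batch size is not yet known.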