src.utils.training_utils

  1import keras
  2import keras.ops as ops
  3import numpy as np
  4import matplotlib.pyplot as plt
  5from tqdm import tqdm
  6
  7from src.models.discriminator import *
  8from src.models.generator import *
  9from src.models.keypoint_detector import *
 10
 11def generate_identity_jacobians(batch, num_kp):
 12    """
 13    Generate a batch of identity Jacobian matrices for keypoints.
 14
 15    Each Jacobian is a 2x2 identity matrix, indicating no local transformation.
 16    The output is used in motion transfer models where identity Jacobians are
 17    assigned to keypoints that do not contribute to deformation.
 18
 19    Parameters:
 20        batch (int): Number of batches (B), typically corresponding to the batch size of input images.
 21        num_kp (int): Number of keypoints (n) per image.
 22
 23    Returns:
 24        tf.Tensor: A tensor of shape (B, n, 2, 2) containing identity matrices for each keypoint in the batch.
 25    """
 26    identity_matrix = ops.convert_to_tensor([[1, 0], [0, 1]], dtype="float32")  # (2, 2)
 27
 28    # Expand to (B, n, 2, 2)
 29    identity_jacobians = ops.expand_dims(identity_matrix, axis=0)  # (1, 2, 2)
 30    identity_jacobians = ops.repeat(identity_jacobians, batch * num_kp, axis=0)
 31    identity_jacobians = ops.reshape(identity_jacobians, (batch, num_kp, 2, 2))
 32
 33    return identity_jacobians
 34
 35
 36import numpy as np
 37import matplotlib.pyplot as plt
 38from tqdm import tqdm
 39from tensorflow import keras
 40
 41def train_motion_model(X, Y,
 42                       generator,
 43                       gan_model,
 44                       discriminator,
 45                       batch_size=16,
 46                       epochs=10,
 47                       save_path='./working',
 48                       preview=True):
 49    """
 50    Trains the generator and GAN model using the provided datasets.
 51
 52    Args:
 53        X (tuple): Tuple of (X0, X1) input arrays.
 54        Y (np.ndarray): Ground truth output images.
 55        generator (keras.Model): Generator model.
 56        gan_model (keras.Model): GAN model combining generator and discriminator.
 57        discriminator (keras.Model): Discriminator model.
 58        batch_size (int): Number of samples per training batch.
 59        epochs (int): Number of training epochs.
 60        save_path (str): Path to save model weights.
 61        preview (bool): Whether to show output images during training.
 62    """
 63    total_inst = (Y.shape[0] - 1) // batch_size * batch_size
 64    X0, X1 = X[0][:total_inst], X[1][:total_inst]
 65    Y = Y[:total_inst]
 66
 67    generator.layers[-1].batch_size = batch_size
 68
 69    for epoch in range(epochs):
 70        print(f"Epoch {epoch + 1}/{epochs}")
 71        for batch_number in tqdm(range(0, total_inst, batch_size), desc="Batches"):
 72            x_batch_0 = X0[batch_number: batch_number + batch_size]
 73            x_batch_1 = X1[batch_number: batch_number + batch_size]
 74            y_batch = Y[batch_number: batch_number + batch_size]
 75
 76            # Train generator directly (supervised)
 77            generator.train_on_batch((x_batch_0, x_batch_1), y_batch)
 78
 79            # Train discriminator to distinguish real vs fake
 80            loss_real = discriminator.train_on_batch(x_batch_1, keras.ops.zeros((batch_size, 1)))
 81            print('Discriminator Loss (real):', loss_real)
 82
 83            # Freeze discriminator for generator adversarial training
 84            gan_model.trainable = False
 85            gan_model.layers[-1].trainable = True
 86
 87            loss_gen = gan_model.train_on_batch((x_batch_0, x_batch_1), keras.ops.ones((batch_size, 1)))
 88            print('GAN Loss (G):', loss_gen)
 89
 90            # Re-enable discriminator for adversarial step
 91            gan_model.trainable = True
 92            gan_model.layers[-1].trainable = False
 93
 94            loss_fake = gan_model.train_on_batch((x_batch_0, x_batch_1), keras.ops.zeros((batch_size, 1)))
 95            print('GAN Loss (D):', loss_fake)
 96
 97        # Visualization after each epoch
 98        if preview:
 99            pred = generator((x_batch_0, x_batch_1))[0, ..., 0]
100            plt.imshow(pred, cmap='gray')
101            plt.title(f"Epoch {epoch + 1}")
102            plt.axis('off')
103            plt.show()
104
105        # Save model weights
106        generator.save_weights(f'{save_path}/generator.weights.h5')
107        gan_model.save_weights(f'{save_path}/GAN.weights.h5')
108
109    print("Training complete.")
110
111
def setup_keypoint_pipeline(
    keypoint_detector,
    generator,
    discriminator_model,
    image_size=(256, 256, 1),
    batch_size=16,
    warmup_samples=500,
    warmup_epochs=10,
    training_epochs=250,
    num_keypoints=10,
    learning_rate=1e-4,
):
    """
    Warm up the keypoint detector and assemble the generator / GAN models.

    Steps:
    1. Wrap ``keypoint_detector`` so that only its second output is exposed
       and fit that output to identity Jacobians on random-normal images.
    2. Build ``generator_model`` mapping (source, driving) images to the
       generator output and compile it with MSE loss.
    3. Compile ``discriminator_model`` in place with binary cross-entropy.
    4. Build and compile ``gan_model`` mapping (source, driving) images to
       the discriminator's output on the generated image.
    5. Print model summaries.

    Parameters:
        keypoint_detector: Callable model whose output is indexable; outputs
            [0] and [1] are forwarded to ``generator`` and output [1] is the
            warm-up target (presumably keypoints and Jacobians -- confirm).
        generator: Callable taking the 5-tuple (source image, source kp[0],
            source kp[1], driving kp[0], driving kp[1]). Must expose
            ``.model`` and ``.upscaler`` attributes with ``summary()``.
        discriminator_model: Model applied to the generator output.
        image_size (tuple): Per-sample input shape.
        batch_size (int): Accepted but not used by this function; the
            warm-up fit uses a fixed batch size of 50.
        warmup_samples (int): Number of random images used for the warm-up.
        warmup_epochs (int): Epochs of warm-up training.
        training_epochs (int): Accepted but not used by this function.
        num_keypoints (int): Keypoints per image for the warm-up target.
        learning_rate (float): Adam learning rate for the three compiled
            models (the warm-up model uses the default 'adam' optimizer).

    Returns:
        dict: Models keyed by "gan", "generator_model", "keypoint_detector",
        "discriminator" and "aligner" (the warm-up model).
    """

    # ------------------------------
    #   Warmup keypoint detector (align jacobians)
    # ------------------------------
    kp_input = keras.Input(shape=image_size)
    kp_output = keypoint_detector(kp_input)
    kp_aligner = keras.Model(inputs=kp_input, outputs=kp_output[1])
    kp_aligner.compile(optimizer='adam', loss='mse')

    # Dummy warmup training: random noise in, identity Jacobians out.
    kp_aligner.fit(
        keras.random.normal((warmup_samples, *image_size)),
        generate_identity_jacobians(warmup_samples, num_keypoints),
        batch_size=50,
        epochs=warmup_epochs
    )

    # ------------------------------
    #   Set up GAN pipeline
    # ------------------------------
    src_input = keras.Input(shape=image_size)
    drv_input = keras.Input(shape=image_size)

    src_kp = keypoint_detector(src_input)
    drv_kp = keypoint_detector(drv_input)

    gen_out = generator((src_input, src_kp[0], src_kp[1], drv_kp[0], drv_kp[1]))

    generator_model = keras.Model(inputs=[src_input, drv_input], outputs=gen_out)
    generator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='mse',
        run_eagerly=False
    )

    discriminator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # GAN pipeline: discriminator applied to the generator output. The
    # already-built `gen_out` is reused instead of calling the generator a
    # second time on the same symbolic inputs.
    # NOTE(review): the discriminator is NOT frozen here; every layer is
    # trainable when gan_model is compiled.
    disc_out = discriminator_model(gen_out)
    gan_model = keras.Model(inputs=[src_input, drv_input], outputs=disc_out)
    gan_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Debugging summaries
    print("\n🧱 GAN Summary:")
    gan_model.summary()
    print("\n🧱 Generator Backbone Summary:")
    generator.model.summary()
    print("\n🧱 Generator Upscaler Summary:")
    generator.upscaler.summary()

    return {
        "gan": gan_model,
        "generator_model": generator_model,
        "keypoint_detector": keypoint_detector,
        "discriminator": discriminator_model,
        "aligner": kp_aligner
    }
def generate_identity_jacobians(batch, num_kp):
def generate_identity_jacobians(batch, num_kp):
    """
    Build one 2x2 identity Jacobian per keypoint for a whole batch.

    An identity Jacobian encodes "no local transformation" around a keypoint.

    Parameters:
        batch (int): Batch size (B), i.e. the number of images.
        num_kp (int): Number of keypoints (n) per image.

    Returns:
        Tensor: A float32 backend tensor of shape (B, n, 2, 2) in which every
        trailing 2x2 slice is the identity matrix.
    """
    # Single identity matrix with two leading singleton axes: (1, 1, 2, 2).
    eye_2x2 = ops.reshape(ops.eye(2, dtype="float32"), (1, 1, 2, 2))

    # Copy it along the batch and keypoint axes to obtain (B, n, 2, 2).
    return ops.tile(eye_2x2, (batch, num_kp, 1, 1))

Generate a batch of identity Jacobian matrices for keypoints.

Each Jacobian is a 2x2 identity matrix, indicating no local transformation. The output is used in motion transfer models where identity Jacobians are assigned to keypoints that do not contribute to deformation.

Parameters: batch (int): Batch size (B), i.e. the number of input images in the batch. num_kp (int): Number of keypoints (n) per image.

Returns: Tensor: A float32 backend tensor (a tf.Tensor when Keras runs on the TensorFlow backend) of shape (B, n, 2, 2) containing identity matrices for each keypoint in the batch.

def train_motion_model( X, Y, generator, gan_model, discriminator, batch_size=16, epochs=10, save_path='./working', preview=True):
 42def train_motion_model(X, Y,
 43                       generator,
 44                       gan_model,
 45                       discriminator,
 46                       batch_size=16,
 47                       epochs=10,
 48                       save_path='./working',
 49                       preview=True):
 50    """
 51    Trains the generator and GAN model using the provided datasets.
 52
 53    Args:
 54        X (tuple): Tuple of (X0, X1) input arrays.
 55        Y (np.ndarray): Ground truth output images.
 56        generator (keras.Model): Generator model.
 57        gan_model (keras.Model): GAN model combining generator and discriminator.
 58        discriminator (keras.Model): Discriminator model.
 59        batch_size (int): Number of samples per training batch.
 60        epochs (int): Number of training epochs.
 61        save_path (str): Path to save model weights.
 62        preview (bool): Whether to show output images during training.
 63    """
 64    total_inst = (Y.shape[0] - 1) // batch_size * batch_size
 65    X0, X1 = X[0][:total_inst], X[1][:total_inst]
 66    Y = Y[:total_inst]
 67
 68    generator.layers[-1].batch_size = batch_size
 69
 70    for epoch in range(epochs):
 71        print(f"Epoch {epoch + 1}/{epochs}")
 72        for batch_number in tqdm(range(0, total_inst, batch_size), desc="Batches"):
 73            x_batch_0 = X0[batch_number: batch_number + batch_size]
 74            x_batch_1 = X1[batch_number: batch_number + batch_size]
 75            y_batch = Y[batch_number: batch_number + batch_size]
 76
 77            # Train generator directly (supervised)
 78            generator.train_on_batch((x_batch_0, x_batch_1), y_batch)
 79
 80            # Train discriminator to distinguish real vs fake
 81            loss_real = discriminator.train_on_batch(x_batch_1, keras.ops.zeros((batch_size, 1)))
 82            print('Discriminator Loss (real):', loss_real)
 83
 84            # Freeze discriminator for generator adversarial training
 85            gan_model.trainable = False
 86            gan_model.layers[-1].trainable = True
 87
 88            loss_gen = gan_model.train_on_batch((x_batch_0, x_batch_1), keras.ops.ones((batch_size, 1)))
 89            print('GAN Loss (G):', loss_gen)
 90
 91            # Re-enable discriminator for adversarial step
 92            gan_model.trainable = True
 93            gan_model.layers[-1].trainable = False
 94
 95            loss_fake = gan_model.train_on_batch((x_batch_0, x_batch_1), keras.ops.zeros((batch_size, 1)))
 96            print('GAN Loss (D):', loss_fake)
 97
 98        # Visualization after each epoch
 99        if preview:
100            pred = generator((x_batch_0, x_batch_1))[0, ..., 0]
101            plt.imshow(pred, cmap='gray')
102            plt.title(f"Epoch {epoch + 1}")
103            plt.axis('off')
104            plt.show()
105
106        # Save model weights
107        generator.save_weights(f'{save_path}/generator.weights.h5')
108        gan_model.save_weights(f'{save_path}/GAN.weights.h5')
109
110    print("Training complete.")

Trains the generator and GAN model using the provided datasets.

Args: X (tuple): Tuple of (X0, X1) input arrays. Y (np.ndarray): Ground truth output images. generator (keras.Model): Generator model. gan_model (keras.Model): GAN model combining generator and discriminator. discriminator (keras.Model): Discriminator model. batch_size (int): Number of samples per training batch. epochs (int): Number of training epochs. save_path (str): Path to save model weights. preview (bool): Whether to show output images during training.

def setup_keypoint_pipeline( keypoint_detector, generator, discriminator_model, image_size=(256, 256, 1), batch_size=16, warmup_samples=500, warmup_epochs=10, training_epochs=250, num_keypoints=10, learning_rate=0.0001):
def setup_keypoint_pipeline(
    keypoint_detector,
    generator,
    discriminator_model,
    image_size=(256, 256, 1),
    batch_size=16,
    warmup_samples=500,
    warmup_epochs=10,
    training_epochs=250,
    num_keypoints=10,
    learning_rate=1e-4,
):
    """
    Warm up the keypoint detector and assemble the generator / GAN models.

    Steps:
    1. Wrap ``keypoint_detector`` so that only its second output is exposed
       and fit that output to identity Jacobians on random-normal images.
    2. Build ``generator_model`` mapping (source, driving) images to the
       generator output and compile it with MSE loss.
    3. Compile ``discriminator_model`` in place with binary cross-entropy.
    4. Build and compile ``gan_model`` mapping (source, driving) images to
       the discriminator's output on the generated image.
    5. Print model summaries.

    Parameters:
        keypoint_detector: Callable model whose output is indexable; outputs
            [0] and [1] are forwarded to ``generator`` and output [1] is the
            warm-up target (presumably keypoints and Jacobians -- confirm).
        generator: Callable taking the 5-tuple (source image, source kp[0],
            source kp[1], driving kp[0], driving kp[1]). Must expose
            ``.model`` and ``.upscaler`` attributes with ``summary()``.
        discriminator_model: Model applied to the generator output.
        image_size (tuple): Per-sample input shape.
        batch_size (int): Accepted but not used by this function; the
            warm-up fit uses a fixed batch size of 50.
        warmup_samples (int): Number of random images used for the warm-up.
        warmup_epochs (int): Epochs of warm-up training.
        training_epochs (int): Accepted but not used by this function.
        num_keypoints (int): Keypoints per image for the warm-up target.
        learning_rate (float): Adam learning rate for the three compiled
            models (the warm-up model uses the default 'adam' optimizer).

    Returns:
        dict: Models keyed by "gan", "generator_model", "keypoint_detector",
        "discriminator" and "aligner" (the warm-up model).
    """

    # ------------------------------
    #   Warmup keypoint detector (align jacobians)
    # ------------------------------
    kp_input = keras.Input(shape=image_size)
    kp_output = keypoint_detector(kp_input)
    kp_aligner = keras.Model(inputs=kp_input, outputs=kp_output[1])
    kp_aligner.compile(optimizer='adam', loss='mse')

    # Dummy warmup training: random noise in, identity Jacobians out.
    kp_aligner.fit(
        keras.random.normal((warmup_samples, *image_size)),
        generate_identity_jacobians(warmup_samples, num_keypoints),
        batch_size=50,
        epochs=warmup_epochs
    )

    # ------------------------------
    #   Set up GAN pipeline
    # ------------------------------
    src_input = keras.Input(shape=image_size)
    drv_input = keras.Input(shape=image_size)

    src_kp = keypoint_detector(src_input)
    drv_kp = keypoint_detector(drv_input)

    gen_out = generator((src_input, src_kp[0], src_kp[1], drv_kp[0], drv_kp[1]))

    generator_model = keras.Model(inputs=[src_input, drv_input], outputs=gen_out)
    generator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='mse',
        run_eagerly=False
    )

    discriminator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # GAN pipeline: discriminator applied to the generator output. The
    # already-built `gen_out` is reused instead of calling the generator a
    # second time on the same symbolic inputs.
    # NOTE(review): the discriminator is NOT frozen here; every layer is
    # trainable when gan_model is compiled.
    disc_out = discriminator_model(gen_out)
    gan_model = keras.Model(inputs=[src_input, drv_input], outputs=disc_out)
    gan_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Debugging summaries
    print("\n🧱 GAN Summary:")
    gan_model.summary()
    print("\n🧱 Generator Backbone Summary:")
    generator.model.summary()
    print("\n🧱 Generator Upscaler Summary:")
    generator.upscaler.summary()

    return {
        "gan": gan_model,
        "generator_model": generator_model,
        "keypoint_detector": keypoint_detector,
        "discriminator": discriminator_model,
        "aligner": kp_aligner
    }

Sets up a general training pipeline for keypoint-based image generation using GAN.

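Returns: a dict with the keys "gan" (GAN model), "generator_model" (Generator model), "keypoint_detector" (Keypoint detector), "discriminator" (Discriminator) and "aligner" (Aligner, the warmup model).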