src.utils.training_utils
import keras
import keras.ops as ops
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.models.discriminator import *
from src.models.generator import *
from src.models.keypoint_detector import *


def generate_identity_jacobians(batch, num_kp):
    """
    Build a stack of 2x2 identity Jacobians, one per keypoint per sample.

    An identity Jacobian represents "no local transformation" around a
    keypoint. In this module the result is used as the regression target
    when warming up the keypoint detector's Jacobian output.

    Parameters:
        batch (int): Number of samples (B) -- the leading dimension of the
            result (e.g. a batch size or dataset size), not a count of batches.
        num_kp (int): Number of keypoints (n) per sample.

    Returns:
        A float32 backend tensor (e.g. tf.Tensor under the TensorFlow
        backend) of shape (B, n, 2, 2) whose every 2x2 slice is the identity.
    """
    eye_2x2 = ops.eye(2, dtype="float32")            # (2, 2)
    eye_2x2 = ops.reshape(eye_2x2, (1, 1, 2, 2))     # (1, 1, 2, 2)
    # Replicate across samples and keypoints.
    return ops.tile(eye_2x2, (batch, num_kp, 1, 1))  # (B, n, 2, 2)


def train_motion_model(X, Y,
                       generator,
                       gan_model,
                       discriminator,
                       batch_size=16,
                       epochs=10,
                       save_path='./working',
                       preview=True):
    """
    Train the generator, discriminator and GAN model on paired frames.

    Every full batch runs four updates, using the label convention
    real = 0, generated = 1:

      1. ``generator`` is fitted directly to ``Y`` (supervised step).
      2. ``discriminator`` is fitted to label the real frames ``X[1]`` as 0.
      3. Everything in ``gan_model`` except its last layer is frozen and it
         is fitted to label generated frames as 1. When ``gan_model`` comes
         from ``setup_keypoint_pipeline`` its last layer is the
         discriminator, so this is the discriminator's "fake" step.
      4. Only the last layer is frozen and ``gan_model`` is fitted towards
         label 0, i.e. the generator learns to make its output look real.

    Samples beyond the last full batch are dropped.

    Args:
        X (tuple): Tuple of (X0, X1) input arrays; X1 also serves as the
            real examples for the discriminator.
        Y (np.ndarray): Ground truth output images.
        generator (keras.Model): Generator model.
        gan_model (keras.Model): GAN model combining generator and discriminator.
        discriminator (keras.Model): Discriminator model.
        batch_size (int): Number of samples per training batch.
        epochs (int): Number of training epochs.
        save_path (str): Directory in which model weights are saved.
        preview (bool): Whether to show a generated image after each epoch.
    """
    # Keep only whole batches: the fixed-size label tensors below need it.
    total_inst = Y.shape[0] // batch_size * batch_size
    X0, X1 = X[0][:total_inst], X[1][:total_inst]
    Y = Y[:total_inst]

    # NOTE(review): assumes the last layer of `generator` is the custom
    # generator layer that reads a `batch_size` attribute -- confirm.
    generator.layers[-1].batch_size = batch_size

    # Labels are identical for every batch, so build them once.
    real_labels = keras.ops.zeros((batch_size, 1))
    fake_labels = keras.ops.ones((batch_size, 1))

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for start in tqdm(range(0, total_inst, batch_size), desc="Batches"):
            stop = start + batch_size
            x_batch_0 = X0[start:stop]
            x_batch_1 = X1[start:stop]
            y_batch = Y[start:stop]

            # 1) Supervised reconstruction step for the generator.
            generator.train_on_batch((x_batch_0, x_batch_1), y_batch)

            # 2) Discriminator on real frames (label 0).
            loss_real = discriminator.train_on_batch(x_batch_1, real_labels)
            print('Discriminator Loss (real):', loss_real)

            # 3) Train only gan_model's last layer (the discriminator) to
            #    label generated frames as fake (1).
            # NOTE(review): Keras generally requires re-compiling a model
            # for `trainable` changes made after compile() to take effect.
            gan_model.trainable = False
            gan_model.layers[-1].trainable = True

            loss_d_fake = gan_model.train_on_batch((x_batch_0, x_batch_1), fake_labels)
            print('Discriminator Loss (fake):', loss_d_fake)

            # 4) Freeze the discriminator and train the rest of gan_model so
            #    that generated frames are classified as real (0).
            gan_model.trainable = True
            gan_model.layers[-1].trainable = False

            loss_g_adv = gan_model.train_on_batch((x_batch_0, x_batch_1), real_labels)
            print('Generator Loss (adversarial):', loss_g_adv)

        # Visualize the first sample of the last batch after each epoch;
        # skipped when there were no full batches (x_batch_0 would be unset).
        if preview and total_inst > 0:
            generated = keras.ops.convert_to_numpy(generator((x_batch_0, x_batch_1)))
            plt.imshow(generated[0, ..., 0], cmap='gray')
            plt.title(f"Epoch {epoch + 1}")
            plt.axis('off')
            plt.show()

        # Save model weights
        generator.save_weights(f'{save_path}/generator.weights.h5')
        gan_model.save_weights(f'{save_path}/GAN.weights.h5')

    print("Training complete.")


def setup_keypoint_pipeline(
    keypoint_detector,
    generator,
    discriminator_model,
    image_size=(256, 256, 1),
    batch_size=16,
    warmup_samples=500,
    warmup_epochs=10,
    training_epochs=250,
    num_keypoints=10,
    learning_rate=1e-4,
):
    """
    Build and compile a keypoint-based image-generation GAN pipeline.

    Steps:
      1. Warm up ``keypoint_detector`` by fitting its second output to
         identity Jacobians on random-normal images.
      2. Build ``generator_model``: (source, driving) -> generated image,
         compiled with MSE.
      3. Compile ``discriminator_model`` with binary cross-entropy.
      4. Build ``gan``: (source, driving) -> discriminator score of the
         generated image, compiled with binary cross-entropy.

    Args:
        keypoint_detector: Layer/model applied to an image batch whose
            outputs are indexed [0] and [1] -- presumably keypoints and
            Jacobians (TODO confirm).
        generator: Layer/model called with
            (source, src_out[0], src_out[1], drv_out[0], drv_out[1]); must
            expose ``.model`` and ``.upscaler`` (their summaries are printed).
        discriminator_model: Model scoring a generated image.
        image_size: Shape of a single input image.
        batch_size: Currently unused (warmup uses a fixed batch size of 50).
        warmup_samples: Number of random images used for the warmup.
        warmup_epochs: Number of warmup epochs.
        training_epochs: Currently unused.
        num_keypoints: Keypoints per image in the warmup targets.
        learning_rate: Adam learning rate for the generator, discriminator
            and GAN models (the warmup uses Keras' default Adam settings).

    Returns:
        dict with keys "gan", "generator_model", "keypoint_detector",
        "discriminator" and "aligner" (the warmup model).
    """
    # ------------------------------
    # Warmup keypoint detector (align jacobians)
    # ------------------------------
    # Wrap the detector so the warmup loss is applied to its second output only.
    kp_input = keras.Input(shape=image_size)
    kp_output = keypoint_detector(kp_input)
    kp_aligner = keras.Model(inputs=kp_input, outputs=kp_output[1])
    kp_aligner.compile(optimizer='adam', loss='mse')

    # Dummy warmup: random-normal images regressed onto identity Jacobians.
    kp_aligner.fit(
        keras.random.normal((warmup_samples, *image_size)),
        generate_identity_jacobians(warmup_samples, num_keypoints),
        batch_size=50,
        epochs=warmup_epochs
    )

    # ------------------------------
    # Set up GAN pipeline
    # ------------------------------
    src_input = keras.Input(shape=image_size)
    drv_input = keras.Input(shape=image_size)

    # The same detector instance (shared weights) processes both frames.
    src_kp = keypoint_detector(src_input)
    drv_kp = keypoint_detector(drv_input)

    gen_out = generator((src_input, src_kp[0], src_kp[1], drv_kp[0], drv_kp[1]))

    generator_model = keras.Model(inputs=[src_input, drv_input], outputs=gen_out)
    generator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='mse',
        run_eagerly=False
    )

    discriminator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Adversarial model: the discriminator's score of the generated frame.
    # Nothing is frozen here; the training loop toggles `trainable` flags.
    disc_out = discriminator_model(gen_out)
    gan_model = keras.Model(inputs=[src_input, drv_input], outputs=disc_out)
    gan_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Debugging summaries
    print("\n🧱 GAN Summary:")
    gan_model.summary()
    print("\n🧱 Generator Backbone Summary:")
    generator.model.summary()
    print("\n🧱 Generator Upscaler Summary:")
    generator.upscaler.summary()

    return {
        "gan": gan_model,
        "generator_model": generator_model,
        "keypoint_detector": keypoint_detector,
        "discriminator": discriminator_model,
        "aligner": kp_aligner
    }
def generate_identity_jacobians(batch, num_kp):
    """
    Build a stack of 2x2 identity Jacobians, one per keypoint per sample.

    An identity Jacobian represents "no local transformation" around a
    keypoint. In this module the result is used as the regression target
    when warming up the keypoint detector's Jacobian output.

    Parameters:
        batch (int): Number of samples (B) -- the leading dimension of the
            result (e.g. a batch size or dataset size), not a count of batches.
        num_kp (int): Number of keypoints (n) per sample.

    Returns:
        A float32 backend tensor (e.g. tf.Tensor under the TensorFlow
        backend) of shape (B, n, 2, 2) whose every 2x2 slice is the identity.
    """
    eye_2x2 = ops.eye(2, dtype="float32")            # (2, 2)
    eye_2x2 = ops.reshape(eye_2x2, (1, 1, 2, 2))     # (1, 1, 2, 2)
    # Replicate across samples and keypoints.
    return ops.tile(eye_2x2, (batch, num_kp, 1, 1))  # (B, n, 2, 2)
Generate a batch of identity Jacobian matrices for keypoints.
Each Jacobian is a 2x2 identity matrix, indicating no local transformation. The output is used in motion transfer models where identity Jacobians are assigned to keypoints that do not contribute to deformation.
Parameters: batch (int): Number of samples (B), i.e. the leading dimension of the result — typically the batch size (or dataset size) of the input images, not a count of batches. num_kp (int): Number of keypoints (n) per image.
Returns: Tensor: A float32 backend tensor (a tf.Tensor under the TensorFlow backend) of shape (B, n, 2, 2) containing an identity matrix for each keypoint of each sample.
def train_motion_model(X, Y,
                       generator,
                       gan_model,
                       discriminator,
                       batch_size=16,
                       epochs=10,
                       save_path='./working',
                       preview=True):
    """
    Train the generator, discriminator and GAN model on paired frames.

    Every full batch runs four updates, using the label convention
    real = 0, generated = 1:

      1. ``generator`` is fitted directly to ``Y`` (supervised step).
      2. ``discriminator`` is fitted to label the real frames ``X[1]`` as 0.
      3. Everything in ``gan_model`` except its last layer is frozen and it
         is fitted to label generated frames as 1. When ``gan_model`` comes
         from ``setup_keypoint_pipeline`` its last layer is the
         discriminator, so this is the discriminator's "fake" step.
      4. Only the last layer is frozen and ``gan_model`` is fitted towards
         label 0, i.e. the generator learns to make its output look real.

    Samples beyond the last full batch are dropped.

    Args:
        X (tuple): Tuple of (X0, X1) input arrays; X1 also serves as the
            real examples for the discriminator.
        Y (np.ndarray): Ground truth output images.
        generator (keras.Model): Generator model.
        gan_model (keras.Model): GAN model combining generator and discriminator.
        discriminator (keras.Model): Discriminator model.
        batch_size (int): Number of samples per training batch.
        epochs (int): Number of training epochs.
        save_path (str): Directory in which model weights are saved.
        preview (bool): Whether to show a generated image after each epoch.
    """
    # Keep only whole batches: the fixed-size label tensors below need it.
    total_inst = Y.shape[0] // batch_size * batch_size
    X0, X1 = X[0][:total_inst], X[1][:total_inst]
    Y = Y[:total_inst]

    # NOTE(review): assumes the last layer of `generator` is the custom
    # generator layer that reads a `batch_size` attribute -- confirm.
    generator.layers[-1].batch_size = batch_size

    # Labels are identical for every batch, so build them once.
    real_labels = keras.ops.zeros((batch_size, 1))
    fake_labels = keras.ops.ones((batch_size, 1))

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        for start in tqdm(range(0, total_inst, batch_size), desc="Batches"):
            stop = start + batch_size
            x_batch_0 = X0[start:stop]
            x_batch_1 = X1[start:stop]
            y_batch = Y[start:stop]

            # 1) Supervised reconstruction step for the generator.
            generator.train_on_batch((x_batch_0, x_batch_1), y_batch)

            # 2) Discriminator on real frames (label 0).
            loss_real = discriminator.train_on_batch(x_batch_1, real_labels)
            print('Discriminator Loss (real):', loss_real)

            # 3) Train only gan_model's last layer (the discriminator) to
            #    label generated frames as fake (1).
            # NOTE(review): Keras generally requires re-compiling a model
            # for `trainable` changes made after compile() to take effect.
            gan_model.trainable = False
            gan_model.layers[-1].trainable = True

            loss_d_fake = gan_model.train_on_batch((x_batch_0, x_batch_1), fake_labels)
            print('Discriminator Loss (fake):', loss_d_fake)

            # 4) Freeze the discriminator and train the rest of gan_model so
            #    that generated frames are classified as real (0).
            gan_model.trainable = True
            gan_model.layers[-1].trainable = False

            loss_g_adv = gan_model.train_on_batch((x_batch_0, x_batch_1), real_labels)
            print('Generator Loss (adversarial):', loss_g_adv)

        # Visualize the first sample of the last batch after each epoch;
        # skipped when there were no full batches (x_batch_0 would be unset).
        if preview and total_inst > 0:
            generated = keras.ops.convert_to_numpy(generator((x_batch_0, x_batch_1)))
            plt.imshow(generated[0, ..., 0], cmap='gray')
            plt.title(f"Epoch {epoch + 1}")
            plt.axis('off')
            plt.show()

        # Save model weights
        generator.save_weights(f'{save_path}/generator.weights.h5')
        gan_model.save_weights(f'{save_path}/GAN.weights.h5')

    print("Training complete.")
Trains the generator, the discriminator, and the GAN model on the provided datasets.
Args: X (tuple): Tuple of (X0, X1) input arrays; X1 also serves as the real examples for the discriminator. Y (np.ndarray): Ground truth output images. generator (keras.Model): Generator model. gan_model (keras.Model): GAN model combining generator and discriminator. discriminator (keras.Model): Discriminator model. batch_size (int): Number of samples per training batch; samples beyond the last full batch are dropped. epochs (int): Number of training epochs. save_path (str): Directory in which model weights are saved. preview (bool): Whether to show a generated image after each epoch.
def setup_keypoint_pipeline(
    keypoint_detector,
    generator,
    discriminator_model,
    image_size=(256, 256, 1),
    batch_size=16,
    warmup_samples=500,
    warmup_epochs=10,
    training_epochs=250,
    num_keypoints=10,
    learning_rate=1e-4,
):
    """
    Build and compile a keypoint-based image-generation GAN pipeline.

    Steps:
      1. Warm up ``keypoint_detector`` by fitting its second output to
         identity Jacobians on random-normal images.
      2. Build ``generator_model``: (source, driving) -> generated image,
         compiled with MSE.
      3. Compile ``discriminator_model`` with binary cross-entropy.
      4. Build ``gan``: (source, driving) -> discriminator score of the
         generated image, compiled with binary cross-entropy.

    Args:
        keypoint_detector: Layer/model applied to an image batch whose
            outputs are indexed [0] and [1] -- presumably keypoints and
            Jacobians (TODO confirm).
        generator: Layer/model called with
            (source, src_out[0], src_out[1], drv_out[0], drv_out[1]); must
            expose ``.model`` and ``.upscaler`` (their summaries are printed).
        discriminator_model: Model scoring a generated image.
        image_size: Shape of a single input image.
        batch_size: Currently unused (warmup uses a fixed batch size of 50).
        warmup_samples: Number of random images used for the warmup.
        warmup_epochs: Number of warmup epochs.
        training_epochs: Currently unused.
        num_keypoints: Keypoints per image in the warmup targets.
        learning_rate: Adam learning rate for the generator, discriminator
            and GAN models (the warmup uses Keras' default Adam settings).

    Returns:
        dict with keys "gan", "generator_model", "keypoint_detector",
        "discriminator" and "aligner" (the warmup model).
    """
    # ------------------------------
    # Warmup keypoint detector (align jacobians)
    # ------------------------------
    # Wrap the detector so the warmup loss is applied to its second output only.
    kp_input = keras.Input(shape=image_size)
    kp_output = keypoint_detector(kp_input)
    kp_aligner = keras.Model(inputs=kp_input, outputs=kp_output[1])
    kp_aligner.compile(optimizer='adam', loss='mse')

    # Dummy warmup: random-normal images regressed onto identity Jacobians.
    kp_aligner.fit(
        keras.random.normal((warmup_samples, *image_size)),
        generate_identity_jacobians(warmup_samples, num_keypoints),
        batch_size=50,
        epochs=warmup_epochs
    )

    # ------------------------------
    # Set up GAN pipeline
    # ------------------------------
    src_input = keras.Input(shape=image_size)
    drv_input = keras.Input(shape=image_size)

    # The same detector instance (shared weights) processes both frames.
    src_kp = keypoint_detector(src_input)
    drv_kp = keypoint_detector(drv_input)

    gen_out = generator((src_input, src_kp[0], src_kp[1], drv_kp[0], drv_kp[1]))

    generator_model = keras.Model(inputs=[src_input, drv_input], outputs=gen_out)
    generator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='mse',
        run_eagerly=False
    )

    discriminator_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Adversarial model: the discriminator's score of the generated frame.
    # Nothing is frozen here; the training loop toggles `trainable` flags.
    disc_out = discriminator_model(gen_out)
    gan_model = keras.Model(inputs=[src_input, drv_input], outputs=disc_out)
    gan_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss='binary_crossentropy',
        run_eagerly=False
    )

    # Debugging summaries
    print("\n🧱 GAN Summary:")
    gan_model.summary()
    print("\n🧱 Generator Backbone Summary:")
    generator.model.summary()
    print("\n🧱 Generator Upscaler Summary:")
    generator.upscaler.summary()

    return {
        "gan": gan_model,
        "generator_model": generator_model,
        "keypoint_detector": keypoint_detector,
        "discriminator": discriminator_model,
        "aligner": kp_aligner
    }
Sets up a general training pipeline for keypoint-based image generation using GAN.
Returns: a dict with the keys "gan" (GAN model), "generator_model" (generator model), "keypoint_detector" (keypoint detector), "discriminator" (discriminator) and "aligner" (warmup model).