src.utils.model_utils
import keras
import keras.ops as ops


def grid_transform(inputs, grid, order=1, fill_mode="constant", fill_value=0, batch=None):
    """
    Applies a spatial transformation to the input tensor using a sampling grid.

    This function performs image warping by sampling the `inputs` tensor at
    the specified `grid` coordinates using interpolation. The grid values are
    expected to be normalized to the range [-1, 1], where -1 and 1 correspond
    to the top-left and bottom-right corners respectively.

    Parameters:
        inputs (Tensor): Input tensor of shape (B, H, W, C), representing a
            batch of images.
        grid (Tensor): Sampling grid of shape (B, H_out, W_out, 2), containing
            normalized coordinates (x, y) to sample from the input.
        order (int): Interpolation order. 1 for bilinear, 0 for nearest neighbor.
        fill_mode (str): Points outside the boundaries are filled according to
            this mode. Supported: "constant", "nearest", "reflect", etc.
        fill_value (float): Value used for points sampled outside the
            boundaries when `fill_mode="constant"`.
        batch (int, optional): Manually specified batch size for when it
            cannot be inferred from `inputs` (e.g., in a tracing context).

    Returns:
        Tensor: Warped output tensor of shape (B, H_out, W_out, C), where each
        image has been resampled using the provided grid.
    """
    # Assume inputs has static shape (B, H, W, C).
    if batch is not None:
        B = batch
    else:
        B = inputs.shape[0]
    H = inputs.shape[1]
    W = inputs.shape[2]
    C = inputs.shape[3]

    # Dynamic output grid dims.
    grid_shape = ops.shape(grid)
    H_out = grid_shape[1]
    W_out = grid_shape[2]

    # Convert normalized grid coordinates to pixel coordinates.
    # grid[..., 0] is x, grid[..., 1] is y.
    x = (grid[..., 0] + 1.0) * (ops.cast(W, "float32") - 1) / 2.0
    y = (grid[..., 1] + 1.0) * (ops.cast(H, "float32") - 1) / 2.0

    outputs_list = []

    for b in range(B):
        channels_out = []
        # Swap coordinate order: (y, x) for image indexing.
        coords = ops.stack([y[b], x[b]], axis=-1)   # (H_out, W_out, 2)
        coords_flat = ops.reshape(coords, (-1, 2))  # (H_out * W_out, 2)
        # Transpose to shape (2, N).
        coords_flat = ops.transpose(coords_flat, [1, 0])

        for c in range(C):
            channel_img = inputs[b, :, :, c]  # (H, W)
            if fill_mode == "constant":
                # Pad image with one pixel on each side.
                padded_img = ops.pad(channel_img, [[1, 1], [1, 1]], constant_values=fill_value)
                padded_H = H + 2
                padded_W = W + 2
                # Shift coordinates by +1 to index into the padded image.
                coords_adj = coords_flat + 1.0
                # For bilinear interpolation, we need floor and ceil of each coordinate.
                y_coords = coords_adj[0]
                x_coords = coords_adj[1]
                y0 = ops.floor(y_coords)
                y1 = ops.ceil(y_coords)
                x0 = ops.floor(x_coords)
                x1 = ops.ceil(x_coords)
                # Valid if both floor and ceil are within [0, padded_dim - 1].
                valid = (y0 >= 0) & (y1 < padded_H) & (x0 >= 0) & (x1 < padded_W)

                # For sampling, clip coordinates into the valid range.
                y_clip = ops.clip(y_coords, 0, padded_H - 1)
                x_clip = ops.clip(x_coords, 0, padded_W - 1)
                coords_clipped = ops.stack([y_clip, x_clip], axis=0)

                # Use map_coordinates with a safe fill_mode.
                sampled = keras.ops.image.map_coordinates(
                    padded_img,
                    coords_clipped,
                    order=order,
                    fill_mode="nearest",  # Invalid positions are overridden below.
                    fill_value=fill_value,
                )
                # Replace values at invalid positions with fill_value.
                valid = ops.cast(valid, "float32")  # (N,)
                sampled = valid * sampled + (1 - valid) * fill_value
            else:
                # For non-constant fill modes, use the coordinates as-is.
                sampled = keras.ops.image.map_coordinates(
                    channel_img,
                    coords_flat,
                    order=order,
                    fill_mode=fill_mode,
                    fill_value=fill_value,
                )
            sampled_reshaped = ops.reshape(sampled, (H_out, W_out))
            channels_out.append(sampled_reshaped)
        batch_out = ops.stack(channels_out, axis=-1)  # (H_out, W_out, C)
        outputs_list.append(batch_out)

    outputs = ops.stack(outputs_list, axis=0)  # (B, H_out, W_out, C)
    return outputs


def generate_grid(shape):
    """
    Generates a normalized 2D sampling grid for spatial transformations.

    The grid contains coordinates in the range [-1, 1], where:
    - x varies from -1 (left) to 1 (right)
    - y varies from 1 (top) to -1 (bottom), following image coordinates

    This grid is used for resampling operations like spatial warping or
    optical flow-based motion transfer.

    Parameters:
        shape (tuple): A 3-tuple (batch_size, height, width) specifying the
            desired grid shape.

    Returns:
        Tensor: A grid tensor of shape (batch_size, height, width, 2), where
        each position contains (x, y) coordinates normalized to [-1, 1].
    """
    batch_size, h, w = shape

    # Generate normalized coordinates.
    x = ops.linspace(-1.0, 1.0, w)  # Linearly spaced x-coordinates.
    y = ops.linspace(1.0, -1.0, h)  # Linearly spaced y-coordinates (flipped y-axis).

    # Create meshgrid.
    X, Y = ops.meshgrid(x, y, indexing="xy")  # Each of shape (h, w).

    # Stack X and Y into (h, w, 2).
    grid = ops.stack([X, Y], axis=-1)

    # Expand for batch dimension: (batch_size, h, w, 2).
    grid = ops.expand_dims(grid, axis=0)
    grid = ops.repeat(grid, batch_size, axis=0)

    return grid  # Shape: (batch_size, h, w, 2)


def grid_to_gaussian(grid, variance=0.01):
    """
    Converts a coordinate grid into a 2D isotropic Gaussian heatmap.

    Each point in the grid is interpreted as a coordinate offset from the
    center, and the function computes the corresponding Gaussian activation
    based on the squared distance from the origin.

    This is commonly used to convert keypoint offsets or sampling grids into
    soft spatial attention maps.

    Parameters:
        grid (Tensor): Tensor of shape (batch, h, w, 2), where each (x, y)
            pair represents an offset from the center.
        variance (float): Scalar controlling the spread of the Gaussian
            (default 0.01). Smaller values produce sharper peaks.

    Returns:
        Tensor: A tensor of shape (batch, h, w, 1) containing the unnormalized
        Gaussian activation over the grid.
    """
    squared_distance = ops.sum(ops.square(grid), axis=-1, keepdims=True)  # x^2 + y^2
    gaussian = ops.exp(-0.5 * squared_distance / variance)  # Apply the Gaussian formula.
    return gaussian  # Shape: (batch, h, w, 1)


def kp2gaussian(keypoints, variance, image_shape, batch_size=None):
    """
    Converts normalized keypoint coordinates into 2D Gaussian heatmaps.

    Parameters:
        keypoints (Tensor): Tensor of shape (B, n, 2), where each keypoint is
            (x, y) in normalized coordinates.
        variance (float): Scalar controlling the spread of the Gaussian.
        image_shape (tuple): Tuple (B, h, w) specifying the output heatmap
            dimensions.
        batch_size (int, optional): Optional override for the batch size, used
            when tracing or shape inference is needed.

    Returns:
        Tensor: A tensor of shape (B, n, h, w) containing a Gaussian heatmap
        for each keypoint.
    """
    if batch_size is None:
        B, h, w = image_shape
    else:
        B = batch_size
        _, h, w = image_shape

    grid = generate_grid((B, h, w))  # (B, h, w, 2)
    grid_exp = ops.expand_dims(grid, axis=1)  # (B, 1, h, w, 2)
    kp_exp = ops.expand_dims(ops.expand_dims(keypoints, axis=2), axis=2)  # (B, n, 1, 1, 2)

    delta = grid_exp - kp_exp  # (B, n, h, w, 2)
    squared_distance = ops.sum(ops.square(delta), axis=-1, keepdims=True)  # (B, n, h, w, 1)
    gaussian = ops.exp(-0.5 * squared_distance / variance)  # (B, n, h, w, 1)

    return ops.squeeze(gaussian, axis=-1)  # (B, n, h, w)


def get_keypoint_coordinates(image, grid):
    """
    Computes keypoint coordinates from a weighted image and a spatial grid.

    This function estimates keypoint positions by computing the spatial
    expectation over the grid, weighted by pixel values in the input image.

    Parameters:
        image (Tensor): Tensor of shape (B, H, W, C), typically a heatmap or
            attention map per keypoint.
        grid (Tensor): Tensor of shape (B, H, W, 2), containing normalized
            (x, y) coordinates in the range [-1, 1].

    Returns:
        Tensor: Tensor of shape (B, 2, C) containing keypoint coordinates
        (x, y) for each channel (usually one channel per keypoint).
    """
    grid_x = ops.expand_dims(grid[..., 0], axis=-1)  # (B, H, W, 1)
    grid_y = ops.expand_dims(grid[..., 1], axis=-1)  # (B, H, W, 1)

    latent_x = image * grid_x  # (B, H, W, C)
    latent_y = image * grid_y  # (B, H, W, C)

    kp_x = ops.sum(latent_x, axis=[1, 2])  # (B, C)
    kp_y = ops.sum(latent_y, axis=[1, 2])  # (B, C)

    return ops.stack([kp_x, kp_y], axis=1)  # (B, 2, C)


def sparse_motion(image_shape,
                  sparse_image_keypoints,
                  driving_image_keypoints,
                  sparse_image_jacobians,
                  driving_image_jacobians):
    """
    Computes a per-keypoint motion field by transforming pixel-wise coordinates
    from the driving frame to the source (sparse) frame using Jacobian-based
    warping.

    Parameters:
        image_shape (tuple): Tuple (B, h, w) giving the batch size and output
            resolution.
        sparse_image_keypoints (Tensor): Tensor of shape (B, n, 2), source keypoints.
        driving_image_keypoints (Tensor): Tensor of shape (B, n, 2), driving keypoints.
        sparse_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at source keypoints.
        driving_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at driving keypoints.

    Returns:
        Tensor: Transformed coordinate grid of shape (B, n, h, w, 2).
    """
    B, h, w = image_shape
    n = ops.shape(sparse_image_keypoints)[1]

    grid = ops.expand_dims(generate_grid(image_shape), axis=1)  # (B, 1, h, w, 2)
    grid = ops.repeat(grid, n, axis=1)  # (B, n, h, w, 2)

    driving_kp = ops.expand_dims(ops.expand_dims(driving_image_keypoints, axis=2), axis=3)  # (B, n, 1, 1, 2)
    diff = grid - driving_kp  # (B, n, h, w, 2)

    # Small diagonal jitter keeps the matrix inverse numerically stable.
    inv_driving = ops.linalg.inv(driving_image_jacobians + ops.eye(2, 2) * 1e-6)  # (B, n, 2, 2)
    T = ops.matmul(sparse_image_jacobians, inv_driving)  # (B, n, 2, 2)

    diff_transformed = ops.einsum('bnij,bnxyj->bnxyi', T, diff)  # (B, n, h, w, 2)

    sparse_kp = ops.expand_dims(ops.expand_dims(sparse_image_keypoints, axis=2), axis=3)
    return diff_transformed + sparse_kp  # (B, n, h, w, 2)


def apply_sparse_motion(image, motion_field, order=1, fill_mode="constant", fill_value=0, batch_size=None):
    """
    Applies a sparse motion field to an image using interpolation.

    Parameters:
        image (Tensor): Input tensor of shape (B, h, w, 1) or (B, h, w).
        motion_field (Tensor): Motion field of shape (B, n, h, w, 2), one per keypoint.
        order (int): Interpolation order.
        fill_mode (str): Fill mode for sampling.
        fill_value (float): Fill value for constant mode.
        batch_size (int, optional): Batch size override.

    Returns:
        Tensor: Output of shape (B, h, w, n), one warped image per keypoint.
    """
    if len(image.shape) == 3:
        image = ops.expand_dims(image, axis=-1)

    B = batch_size if batch_size is not None else image.shape[0]
    n = motion_field.shape[1]

    deformed = []
    for i in range(n):
        grid_i = motion_field[:, i]
        warped = grid_transform(image, grid_i, order=order, fill_mode=fill_mode,
                                fill_value=fill_value, batch=B)
        deformed.append(warped)

    return ops.concatenate(deformed, axis=-1)
Functions

def grid_transform(inputs, grid, order=1, fill_mode="constant", fill_value=0, batch=None)
Applies a spatial transformation to the input tensor using a sampling grid.

This function performs image warping by sampling the `inputs` tensor at the specified `grid` coordinates using interpolation. The grid values are expected to be normalized to the range [-1, 1], where -1 and 1 correspond to the top-left and bottom-right corners respectively.

Parameters:
    inputs (Tensor): Input tensor of shape (B, H, W, C), representing a batch of images.
    grid (Tensor): Sampling grid of shape (B, H_out, W_out, 2), containing normalized coordinates (x, y) to sample from the input.
    order (int): Interpolation order. 1 for bilinear, 0 for nearest neighbor.
    fill_mode (str): Points outside the boundaries are filled according to this mode. Supported: "constant", "nearest", "reflect", etc.
    fill_value (float): Value used for points sampled outside the boundaries when `fill_mode="constant"`.
    batch (int, optional): Manually specified batch size for when it cannot be inferred from `inputs` (e.g., in a tracing context).

Returns:
    Tensor: Warped output tensor of shape (B, H_out, W_out, C), where each image has been resampled using the provided grid.
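As a sanity check, an identity warp should return the input unchanged. A minimal sketch, assuming the package root is on sys.path so the module imports as src.utils.model_utils; note the grid is built in this function's own convention (y = -1 at the top row):

import numpy as np
import keras.ops as ops
from src.utils.model_utils import grid_transform

B, H, W, C = 2, 8, 8, 3
image = ops.convert_to_tensor(np.random.rand(B, H, W, C).astype("float32"))

# Identity grid in grid_transform's convention: both axes run -1 .. 1,
# with (-1, -1) at the top-left corner.
xs = ops.linspace(-1.0, 1.0, W)
ys = ops.linspace(-1.0, 1.0, H)
X, Y = ops.meshgrid(xs, ys, indexing="xy")                   # each (H, W)
grid = ops.stack([X, Y], axis=-1)                            # (H, W, 2)
grid = ops.repeat(ops.expand_dims(grid, axis=0), B, axis=0)  # (B, H, W, 2)

warped = grid_transform(image, grid, order=1)
print(warped.shape)  # (2, 8, 8, 3); values match `image` up to float precision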
def generate_grid(shape)
Generates a normalized 2D sampling grid for spatial transformations.

The grid contains coordinates in the range [-1, 1], where:
- x varies from -1 (left) to 1 (right)
- y varies from 1 (top) to -1 (bottom), following image coordinates

This grid is used for resampling operations like spatial warping or optical flow-based motion transfer. Note that the y convention here (+1 at the top row) mirrors the one used by `grid_transform`, which maps y = -1 to the top row.

Parameters:
    shape (tuple): A 3-tuple (batch_size, height, width) specifying the desired grid shape.

Returns:
    Tensor: A grid tensor of shape (batch_size, height, width, 2), where each position contains (x, y) coordinates normalized to [-1, 1].
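A quick sketch of the corner convention on a small grid:

from src.utils.model_utils import generate_grid

grid = generate_grid((1, 4, 4))
print(grid.shape)       # (1, 4, 4, 2)
print(grid[0, 0, 0])    # [-1.,  1.]  -> top-left (x, y)
print(grid[0, -1, -1])  # [ 1., -1.]  -> bottom-right (x, y)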
def grid_to_gaussian(grid, variance=0.01)
Converts a coordinate grid into a 2D isotropic Gaussian heatmap.

Each point in the grid is interpreted as a coordinate offset from the center, and the function computes the corresponding Gaussian activation based on the squared distance from the origin.

This is commonly used to convert keypoint offsets or sampling grids into soft spatial attention maps.

Parameters:
    grid (Tensor): Tensor of shape (batch, h, w, 2), where each (x, y) pair represents an offset from the center.
    variance (float): Scalar controlling the spread of the Gaussian (default 0.01). Smaller values produce sharper peaks.

Returns:
    Tensor: A tensor of shape (batch, h, w, 1) containing the Gaussian activation over the grid. The map is not normalized to sum to 1, so it is an activation map rather than a true probability distribution.
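For example, feeding the base grid through this function yields a single blob peaked at the image center (a sketch; the variance value is arbitrary):

import keras.ops as ops
from src.utils.model_utils import generate_grid, grid_to_gaussian

grid = generate_grid((1, 9, 9))
heatmap = grid_to_gaussian(grid, variance=0.05)  # (1, 9, 9, 1)
# The peak sits where the grid offset is (0, 0), i.e., the center pixel.
print(ops.argmax(ops.reshape(heatmap, (-1,))))   # 40 == 4 * 9 + 4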
def kp2gaussian(keypoints, variance, image_shape, batch_size=None)
Converts normalized keypoint coordinates into 2D Gaussian heatmaps.

Parameters:
    keypoints (Tensor): Tensor of shape (B, n, 2), where each keypoint is (x, y) in normalized coordinates.
    variance (float): Scalar controlling the spread of the Gaussian.
    image_shape (tuple): Tuple (B, h, w) specifying the output heatmap dimensions.
    batch_size (int, optional): Optional override for the batch size, used when tracing or shape inference is needed.

Returns:
    Tensor: A tensor of shape (B, n, h, w) containing a Gaussian heatmap for each keypoint.
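A minimal sketch; the keypoint values follow generate_grid's convention, so (-1, 1) is the top-left corner:

import keras.ops as ops
from src.utils.model_utils import kp2gaussian

# One batch, two keypoints: the image center and the top-left corner.
kp = ops.convert_to_tensor([[[0.0, 0.0], [-1.0, 1.0]]], dtype="float32")  # (1, 2, 2)
heatmaps = kp2gaussian(kp, variance=0.01, image_shape=(1, 16, 16))
print(heatmaps.shape)  # (1, 2, 16, 16)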
def get_keypoint_coordinates(image, grid)
Computes keypoint coordinates from a weighted image and a spatial grid.

This function estimates keypoint positions by computing the spatial expectation over the grid, weighted by pixel values in the input image. The expectation interpretation assumes each channel of `image` is non-negative and sums to 1 (e.g., after a spatial softmax); otherwise the result is an unnormalized weighted sum.

Parameters:
    image (Tensor): Tensor of shape (B, H, W, C), typically a heatmap or attention map per keypoint.
    grid (Tensor): Tensor of shape (B, H, W, 2), containing normalized (x, y) coordinates in the range [-1, 1].

Returns:
    Tensor: Tensor of shape (B, 2, C) containing keypoint coordinates (x, y) for each channel (usually one channel per keypoint).
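This pairs with kp2gaussian as a soft-argmax round trip: render a keypoint as a heatmap, normalize it so each channel sums to 1, and recover the coordinates. A sketch with one keypoint:

import keras.ops as ops
from src.utils.model_utils import generate_grid, kp2gaussian, get_keypoint_coordinates

kp = ops.convert_to_tensor([[[0.25, -0.5]]], dtype="float32")  # (1, 1, 2)
hm = kp2gaussian(kp, variance=0.01, image_shape=(1, 64, 64))   # (1, 1, 64, 64)
hm = ops.transpose(hm, [0, 2, 3, 1])                           # (1, 64, 64, 1), channels-last
hm = hm / ops.sum(hm, axis=[1, 2], keepdims=True)              # normalize weights to sum to 1
grid = generate_grid((1, 64, 64))
print(get_keypoint_coordinates(hm, grid))  # ~ [[[0.25], [-0.5]]], shape (1, 2, 1)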
def sparse_motion(image_shape, sparse_image_keypoints, driving_image_keypoints, sparse_image_jacobians, driving_image_jacobians)
Computes a per-keypoint motion field by transforming pixel-wise coordinates from the driving frame to the source (sparse) frame using Jacobian-based warping.

Parameters:
    image_shape (tuple): Tuple (B, h, w) giving the batch size and output resolution.
    sparse_image_keypoints (Tensor): Tensor of shape (B, n, 2), source keypoints.
    driving_image_keypoints (Tensor): Tensor of shape (B, n, 2), driving keypoints.
    sparse_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at source keypoints.
    driving_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at driving keypoints.

Returns:
    Tensor: Transformed coordinate grid of shape (B, n, h, w, 2).
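With identity Jacobians the transform degenerates to a per-keypoint translation, which makes a convenient smoke test (a sketch with random keypoints):

import numpy as np
import keras.ops as ops
from src.utils.model_utils import sparse_motion

B, n = 1, 5
src_kp = ops.convert_to_tensor(np.random.uniform(-1, 1, (B, n, 2)).astype("float32"))
drv_kp = ops.convert_to_tensor(np.random.uniform(-1, 1, (B, n, 2)).astype("float32"))
# Identity Jacobians for every keypoint, shape (1, n, 2, 2).
eye = ops.repeat(ops.reshape(ops.eye(2), (1, 1, 2, 2)), n, axis=1)

motion = sparse_motion((B, 32, 32), src_kp, drv_kp, eye, eye)
print(motion.shape)  # (1, 5, 32, 32, 2); here each field is grid - drv_kp + src_kp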
def apply_sparse_motion(image, motion_field, order=1, fill_mode="constant", fill_value=0, batch_size=None)
Applies a sparse motion field to an image using interpolation.

Parameters:
    image (Tensor): Input tensor of shape (B, h, w, 1) or (B, h, w).
    motion_field (Tensor): Motion field of shape (B, n, h, w, 2), one per keypoint.
    order (int): Interpolation order.
    fill_mode (str): Fill mode for sampling.
    fill_value (float): Fill value for constant mode.
    batch_size (int, optional): Batch size override.

Returns:
    Tensor: Output of shape (B, h, w, n), one warped image per keypoint.
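Putting the last two functions together: with identical source and driving keypoints and identity Jacobians, the motion field reduces to the base sampling grid (a sketch on random data):

import numpy as np
import keras.ops as ops
from src.utils.model_utils import sparse_motion, apply_sparse_motion

B, h, w, n = 1, 32, 32, 5
image = ops.convert_to_tensor(np.random.rand(B, h, w, 1).astype("float32"))
kp = ops.convert_to_tensor(np.random.uniform(-0.5, 0.5, (B, n, 2)).astype("float32"))
eye = ops.repeat(ops.reshape(ops.eye(2), (1, 1, 2, 2)), n, axis=1)

motion = sparse_motion((B, h, w), kp, kp, eye, eye)  # reduces to the base grid
warped = apply_sparse_motion(image, motion, batch_size=B)
print(warped.shape)  # (1, 32, 32, 5): one warped copy of the image per keypoint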