src.utils.model_utils

import keras
import keras.ops as ops
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def grid_transform(inputs, grid, order=1, fill_mode="constant", fill_value=0, batch=None):
    """
    Applies a spatial transformation to the input tensor using a sampling grid.

    This function performs image warping by sampling the `inputs` tensor at the
    given `grid` coordinates using interpolation. The grid values are expected
    to be normalized to the range [-1, 1], where -1 and 1 correspond to the
    top-left and bottom-right corners, respectively.

    Parameters:
        inputs (Tensor): Input tensor of shape (B, H, W, C), representing a batch
                         of images.
        grid (Tensor): Sampling grid of shape (B, H_out, W_out, 2), containing
                       normalized (x, y) coordinates to sample from the input.
        order (int): Interpolation order: 1 for bilinear, 0 for nearest neighbor.
        fill_mode (str): How points outside the boundaries are filled. Supported:
                         "constant", "nearest", "wrap", "mirror", "reflect".
        fill_value (float): Value used for points sampled outside the boundaries
                            when `fill_mode="constant"`.
        batch (int, optional): Batch size override for when it cannot be inferred
                               from `inputs` (e.g., inside a tracing context).

    Returns:
        Tensor: Warped output tensor of shape (B, H_out, W_out, C), where each
                image has been resampled at the provided grid locations.
    """
    # `inputs` is assumed to have a static shape (B, H, W, C).
    if batch is not None:
        B = batch
    else:
        B = inputs.shape[0]
    H = inputs.shape[1]
    W = inputs.shape[2]
    C = inputs.shape[3]

    # The output grid dimensions may be dynamic.
    grid_shape = ops.shape(grid)
    H_out = grid_shape[1]
    W_out = grid_shape[2]

    # Convert normalized grid coordinates to pixel coordinates.
    # grid[..., 0] is x, grid[..., 1] is y.
    x = (grid[..., 0] + 1.0) * (ops.cast(W, "float32") - 1) / 2.0
    y = (grid[..., 1] + 1.0) * (ops.cast(H, "float32") - 1) / 2.0

    outputs_list = []

    for b in range(B):
        channels_out = []
        # Swap coordinate order to (y, x) for image indexing.
        coords = ops.stack([y[b], x[b]], axis=-1)    # (H_out, W_out, 2)
        coords_flat = ops.reshape(coords, (-1, 2))   # (H_out*W_out, 2)
        # Transpose to shape (2, N), as map_coordinates expects.
        coords_flat = ops.transpose(coords_flat, [1, 0])

        for c in range(C):
            channel_img = inputs[b, :, :, c]  # (H, W)
            if fill_mode == "constant":
                # Pad the image with one pixel of fill_value on each side so that
                # bilinear samples just outside the border blend toward fill_value.
                padded_img = ops.pad(channel_img, [[1, 1], [1, 1]], constant_values=fill_value)
                padded_H = H + 2
                padded_W = W + 2
                # Shift coordinates by +1 to index into the padded image.
                coords_adj = coords_flat + 1.0
                # Bilinear interpolation reads both the floor and the ceil of each coordinate.
                y_coords = coords_adj[0]
                x_coords = coords_adj[1]
                y0 = ops.floor(y_coords)
                y1 = ops.ceil(y_coords)
                x0 = ops.floor(x_coords)
                x1 = ops.ceil(x_coords)
                # A sample is valid only if both floor and ceil lie within [0, padded_dim - 1].
                valid = (y0 >= 0) & (y1 < padded_H) & (x0 >= 0) & (x1 < padded_W)

                # Clip coordinates into the valid range before sampling.
                y_clip = ops.clip(y_coords, 0, padded_H - 1)
                x_clip = ops.clip(x_coords, 0, padded_W - 1)
                coords_clipped = ops.stack([y_clip, x_clip], axis=0)

                # Sample with a safe fill_mode; invalid positions are overridden below.
                sampled = keras.ops.image.map_coordinates(
                    padded_img,
                    coords_clipped,
                    order=order,
                    fill_mode="nearest",
                    fill_value=fill_value,
                )
                # Replace values at invalid positions with fill_value.
                valid = ops.cast(valid, "float32")  # (N,)
                sampled = valid * sampled + (1 - valid) * fill_value
            else:
                # For non-constant fill modes, use the coordinates as-is.
                sampled = keras.ops.image.map_coordinates(
                    channel_img,
                    coords_flat,
                    order=order,
                    fill_mode=fill_mode,
                    fill_value=fill_value,
                )
            sampled_reshaped = ops.reshape(sampled, (H_out, W_out))
            channels_out.append(sampled_reshaped)
        batch_out = ops.stack(channels_out, axis=-1)  # (H_out, W_out, C)
        outputs_list.append(batch_out)

    outputs = ops.stack(outputs_list, axis=0)  # (B, H_out, W_out, C)
    return outputs
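
A quick usage sketch (assuming the module is importable as src.utils.model_utils). The grid below is an identity grid built by hand in this function's own convention, where y = -1 addresses the top row, so the warp should closely reproduce the input:

import numpy as np
import keras.ops as ops
from src.utils.model_utils import grid_transform

# Identity grid in grid_transform's convention (y = -1 at the top row).
H = W = 8
xs = ops.linspace(-1.0, 1.0, W)
ys = ops.linspace(-1.0, 1.0, H)
X, Y = ops.meshgrid(xs, ys, indexing="xy")                      # each (H, W)
identity_grid = ops.expand_dims(ops.stack([X, Y], axis=-1), 0)  # (1, H, W, 2)

images = np.random.rand(1, H, W, 3).astype("float32")
warped = grid_transform(images, identity_grid, order=1)  # ~images, shape (1, 8, 8, 3)
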
def generate_grid(shape):
    """
    Generates a normalized 2D sampling grid for spatial transformations.

    The grid contains coordinates in the range [-1, 1], where:
      - x varies from -1 (left) to 1 (right)
      - y varies from 1 (top) to -1 (bottom), following image coordinates

    Note that this places y = +1 at the top row, which is the mirror of the
    convention grid_transform uses (y = -1 at the top row).

    This grid is used for resampling operations such as spatial warping or
    optical-flow-based motion transfer.

    Parameters:
        shape (tuple): A 3-tuple (batch_size, height, width) specifying the
                       desired grid shape.

    Returns:
        Tensor: A grid tensor of shape (batch_size, height, width, 2), where
                each position contains (x, y) coordinates normalized to [-1, 1].
    """
    batch_size, h, w = shape

    # Generate normalized coordinates.
    x = ops.linspace(-1.0, 1.0, w)  # linearly spaced x-coordinates
    y = ops.linspace(1.0, -1.0, h)  # linearly spaced y-coordinates (flipped y-axis)

    # Create the meshgrid; X and Y each have shape (h, w).
    X, Y = ops.meshgrid(x, y, indexing="xy")

    # Stack X and Y into (h, w, 2).
    grid = ops.stack([X, Y], axis=-1)

    # Add and tile the batch dimension: (batch_size, h, w, 2).
    grid = ops.expand_dims(grid, axis=0)
    grid = ops.repeat(grid, batch_size, axis=0)

    return grid  # (batch_size, h, w, 2)
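
A minimal check of the coordinate convention:

from src.utils.model_utils import generate_grid

g = generate_grid((1, 4, 4))  # (1, 4, 4, 2)
# g[0, 0, 0]   -> (-1.0,  1.0): top-left corner
# g[0, -1, -1] -> ( 1.0, -1.0): bottom-right corner
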
def grid_to_gaussian(grid, variance=0.01):
    """
    Converts a coordinate grid into a 2D isotropic Gaussian heatmap.

    Each point in the grid is interpreted as a coordinate offset from the
    center, and the function computes the corresponding Gaussian activation
    from the squared distance to the origin. The output is an unnormalized
    activation map: its values peak at 1.0 rather than summing to 1.

    This is commonly used to convert keypoint offsets or sampling grids into
    soft spatial attention maps.

    Parameters:
        grid (Tensor): Tensor of shape (batch, h, w, 2), where each (x, y) pair
                       represents an offset from the center.
        variance (float): Scalar controlling the spread of the Gaussian
                          (default 0.01). Smaller values produce sharper peaks.

    Returns:
        Tensor: A tensor of shape (batch, h, w, 1) containing the Gaussian
                activations over the grid.
    """
    squared_distance = ops.sum(ops.square(grid), axis=-1, keepdims=True)  # x^2 + y^2
    gaussian = ops.exp(-0.5 * squared_distance / variance)  # exp(-(x^2 + y^2) / (2 * variance))
    return gaussian  # (batch, h, w, 1)
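
For example, feeding it the centered grid from generate_grid yields a single blob that peaks near the image center (a minimal sketch):

from src.utils.model_utils import generate_grid, grid_to_gaussian

heat = grid_to_gaussian(generate_grid((1, 64, 64)), variance=0.01)  # (1, 64, 64, 1)
# Values approach 1.0 where x^2 + y^2 ~ 0, i.e. near the center of the grid.
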
def kp2gaussian(keypoints, variance, image_shape, batch_size=None):
    """
    Converts normalized keypoint coordinates into 2D Gaussian heatmaps.

    Parameters:
        keypoints (Tensor): Tensor of shape (B, n, 2), where each keypoint
                            is (x, y) in normalized [-1, 1] coordinates.
        variance (float): Scalar controlling the spread of the Gaussian.
        image_shape (tuple): Tuple (B, h, w) specifying the output heatmap dimensions.
        batch_size (int, optional): Optional override for the batch size, used
                                    when tracing or shape inference is needed.

    Returns:
        Tensor: A tensor of shape (B, n, h, w) containing a Gaussian heatmap
                for each keypoint.
    """
    if batch_size is None:
        B, h, w = image_shape
    else:
        B = batch_size
        _, h, w = image_shape

    grid = generate_grid((B, h, w))                                        # (B, h, w, 2)
    grid_exp = ops.expand_dims(grid, axis=1)                               # (B, 1, h, w, 2)
    kp_exp = ops.expand_dims(ops.expand_dims(keypoints, axis=2), axis=2)   # (B, n, 1, 1, 2)

    delta = grid_exp - kp_exp                                              # (B, n, h, w, 2)
    squared_distance = ops.sum(ops.square(delta), axis=-1, keepdims=True)  # (B, n, h, w, 1)
    gaussian = ops.exp(-0.5 * squared_distance / variance)                 # (B, n, h, w, 1)

    return ops.squeeze(gaussian, axis=-1)                                  # (B, n, h, w)
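
Usage sketch: two keypoints per image produce two heatmaps, one per keypoint:

import numpy as np
from src.utils.model_utils import kp2gaussian

kps = np.array([[[0.0, 0.0], [0.5, -0.5]]], dtype="float32")     # (1, 2, 2)
maps = kp2gaussian(kps, variance=0.01, image_shape=(1, 64, 64))  # (1, 2, 64, 64)
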
def get_keypoint_coordinates(image, grid):
    """
    Computes keypoint coordinates from a weighted image and a spatial grid.

    This function estimates keypoint positions as the spatial expectation of
    the grid coordinates, weighted by the pixel values of the input image.
    For the result to be a true expectation, each channel of `image` should
    sum to 1 (e.g., after a spatial softmax).

    Parameters:
        image (Tensor): Tensor of shape (B, H, W, C), typically a heatmap
                        or attention map per keypoint.
        grid (Tensor): Tensor of shape (B, H, W, 2), containing normalized
                       (x, y) coordinates in the range [-1, 1].

    Returns:
        Tensor: Tensor of shape (B, 2, C) containing keypoint coordinates
                (x, y) for each channel (usually one channel per keypoint).
    """
    grid_x = ops.expand_dims(grid[..., 0], axis=-1)  # (B, H, W, 1)
    grid_y = ops.expand_dims(grid[..., 1], axis=-1)  # (B, H, W, 1)

    latent_x = image * grid_x  # (B, H, W, C)
    latent_y = image * grid_y  # (B, H, W, C)

    kp_x = ops.sum(latent_x, axis=(1, 2))  # (B, C)
    kp_y = ops.sum(latent_y, axis=(1, 2))  # (B, C)

    return ops.stack([kp_x, kp_y], axis=1)  # (B, 2, C)
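
Because kp2gaussian and get_keypoint_coordinates share the generate_grid convention, a round trip approximately recovers the input keypoints (a sketch; the transpose adapts the (B, n, h, w) heatmaps to the (B, H, W, C) layout this function expects):

import numpy as np
import keras.ops as ops
from src.utils.model_utils import generate_grid, kp2gaussian, get_keypoint_coordinates

kps = np.array([[[0.25, -0.4]]], dtype="float32")               # (1, 1, 2)
hm = kp2gaussian(kps, 0.01, (1, 64, 64))                        # (1, 1, 64, 64)
hm = hm / ops.sum(hm, axis=(2, 3), keepdims=True)               # each map sums to 1
hm = ops.transpose(hm, (0, 2, 3, 1))                            # (1, 64, 64, 1)
rec = get_keypoint_coordinates(hm, generate_grid((1, 64, 64)))  # (1, 2, 1) ~ (0.25, -0.4)
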
def sparse_motion(image_shape,
                  sparse_image_keypoints,
                  driving_image_keypoints,
                  sparse_image_jacobians,
                  driving_image_jacobians):
    """
    Computes a per-keypoint motion field by transforming pixel-wise coordinates
    from the driving frame to the source (sparse) frame using Jacobian-based
    warping. For keypoint k and grid position z, the transformed coordinate is
    p_src_k + J_src_k @ inv(J_drv_k) @ (z - p_drv_k).

    Parameters:
        image_shape (tuple): Tuple (B, h, w) giving the batch size and output resolution.
        sparse_image_keypoints (Tensor): Tensor of shape (B, n, 2), source keypoints.
        driving_image_keypoints (Tensor): Tensor of shape (B, n, 2), driving keypoints.
        sparse_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at source keypoints.
        driving_image_jacobians (Tensor): Tensor of shape (B, n, 2, 2), Jacobians at driving keypoints.

    Returns:
        Tensor: Transformed coordinate grid of shape (B, n, h, w, 2).
    """
    B, h, w = image_shape
    n = ops.shape(sparse_image_keypoints)[1]

    grid = ops.expand_dims(generate_grid(image_shape), axis=1)  # (B, 1, h, w, 2)
    grid = ops.repeat(grid, n, axis=1)                          # (B, n, h, w, 2)

    driving_kp = ops.expand_dims(ops.expand_dims(driving_image_keypoints, axis=2), axis=3)  # (B, n, 1, 1, 2)
    diff = grid - driving_kp  # (B, n, h, w, 2)

    # Regularize with a small multiple of the identity before inverting.
    inv_driving = ops.linalg.inv(driving_image_jacobians + ops.eye(2, 2) * 1e-6)  # (B, n, 2, 2)
    T = ops.matmul(sparse_image_jacobians, inv_driving)  # (B, n, 2, 2)

    diff_transformed = ops.einsum('bnij,bnxyj->bnxyi', T, diff)  # (B, n, h, w, 2)

    sparse_kp = ops.expand_dims(ops.expand_dims(sparse_image_keypoints, axis=2), axis=3)
    return diff_transformed + sparse_kp  # (B, n, h, w, 2)
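
With identity Jacobians the transform reduces to a per-keypoint translation of the grid, which makes a convenient sanity check (a sketch with random keypoints):

import numpy as np
from src.utils.model_utils import sparse_motion

B, n, h, w = 1, 5, 32, 32
kp_src = np.random.uniform(-1, 1, (B, n, 2)).astype("float32")
kp_drv = np.random.uniform(-1, 1, (B, n, 2)).astype("float32")
jac = np.tile(np.eye(2, dtype="float32"), (B, n, 1, 1))      # identity Jacobians
motion = sparse_motion((B, h, w), kp_src, kp_drv, jac, jac)  # (B, n, h, w, 2)
# Here motion[:, k] == grid - kp_drv[:, k] + kp_src[:, k] (up to the 1e-6 regularizer).
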
def apply_sparse_motion(image, motion_field, order=1, fill_mode="constant", fill_value=0, batch_size=None):
    """
    Applies a sparse motion field to an image using interpolation.

    Parameters:
        image (Tensor): Input tensor of shape (B, h, w, C) or (B, h, w).
        motion_field (Tensor): Motion field of shape (B, n, h, w, 2), one grid per keypoint.
        order (int): Interpolation order.
        fill_mode (str): Fill mode for sampling.
        fill_value (float): Fill value for constant mode.
        batch_size (int, optional): Batch size override.

    Returns:
        Tensor: Output of shape (B, h, w, n * C): one warped copy of the image
                per keypoint, so (B, h, w, n) for single-channel input.
    """
    if len(image.shape) == 3:
        image = ops.expand_dims(image, axis=-1)

    B = batch_size if batch_size is not None else image.shape[0]
    n = motion_field.shape[1]

    deformed = []
    for i in range(n):
        grid_i = motion_field[:, i]  # (B, h, w, 2)
        warped = grid_transform(image, grid_i, order=order, fill_mode=fill_mode, fill_value=fill_value, batch=B)
        deformed.append(warped)

    return ops.concatenate(deformed, axis=-1)
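
End-to-end sketch, chaining sparse_motion into apply_sparse_motion to warp one single-channel frame once per keypoint:

import numpy as np
from src.utils.model_utils import sparse_motion, apply_sparse_motion

B, n, h, w = 1, 5, 32, 32
kp_src = np.random.uniform(-1, 1, (B, n, 2)).astype("float32")
kp_drv = np.random.uniform(-1, 1, (B, n, 2)).astype("float32")
jac = np.tile(np.eye(2, dtype="float32"), (B, n, 1, 1))

frame = np.random.rand(B, h, w).astype("float32")            # single-channel frame
motion = sparse_motion((B, h, w), kp_src, kp_drv, jac, jac)  # (B, n, h, w, 2)
warps = apply_sparse_motion(frame, motion, batch_size=B)     # (B, h, w, n)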