training_utils

This module provides utilities for distributed training with Keras and TensorFlow.

Functions:

  • get_distribution_scope(device_type: str) -> ContextManager
    Returns a context manager for executing code in a distributed environment.

Usage: To use the get_distribution_scope function, first set the KERAS_BACKEND environment variable to either 'jax' or 'tensorflow'. Then, call the function with the desired device type ('cpu', 'gpu', or 'tpu').
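For example (assuming the module is importable as training_utils), the backend variable should be set before Keras or this module is imported:

    import os

    # Select the Keras backend before importing keras (or this module).
    os.environ["KERAS_BACKEND"] = "jax"   # or "tensorflow"

    from training_utils import get_distribution_scope

    scope = get_distribution_scope("gpu")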

def get_distribution_scope(device_type):

Returns a context manager for executing code in a distributed environment.

This function supports both the JAX and TensorFlow backends. With either backend it supports CPU, GPU, and TPU devices; for TensorFlow, an appropriate distribution strategy is selected based on the device type.

The context manager returned by this function prints the total number of available devices and the total time taken by the code executed within it.
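The documentation does not show how this reporting is implemented; the following is a minimal sketch, assuming a contextlib-based wrapper (the helper name _timed_scope and its devices argument are hypothetical):

    import time
    from contextlib import contextmanager

    @contextmanager
    def _timed_scope(devices):
        # Hypothetical helper: report the device count, then time the enclosed block.
        print(f"Total devices available: {len(devices)}")
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"Total time taken: {time.perf_counter() - start:.2f} s")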

Args:

  • device_type (str): The type of device to use for distributed training. Can be 'cpu', 'gpu', or 'tpu'.

Returns:

  • A context manager object that can be used to execute code in a distributed environment.

Notes:

  • For the JAX backend, this function uses the keras.distribution API to create a DataParallel distribution for GPU devices. For more information, see the Keras guide on distributed training: https://keras.io/guides/distribution/
  • For the TensorFlow backend, this function uses the appropriate distribution strategy for the device type (see the sketch after this list):
      - For CPU and GPU, it uses tf.distribute.MirroredStrategy
      - For TPU, it uses tf.distribute.TPUStrategy

  • For more information on distributed training with TensorFlow, see the TensorFlow guide: https://www.tensorflow.org/guide/distributed_training
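The implementation is not reproduced here; as a rough sketch of the dispatch described in the notes above, assuming a contextmanager-based helper (the name _distribution_scope is hypothetical; keras.distribution, tf.distribute.MirroredStrategy, and tf.distribute.TPUStrategy are the real APIs named above):

    import os
    from contextlib import contextmanager

    @contextmanager
    def _distribution_scope(device_type):
        # Hypothetical dispatch: choose a distribution mechanism per backend and device.
        backend = os.environ.get("KERAS_BACKEND", "tensorflow")
        if backend == "jax":
            import keras
            if device_type == "gpu":
                # Keras 3 data parallelism across all visible JAX devices.
                keras.distribution.set_distribution(keras.distribution.DataParallel())
            yield
        elif backend == "tensorflow":
            import tensorflow as tf
            if device_type in ("cpu", "gpu"):
                strategy = tf.distribute.MirroredStrategy()
            elif device_type == "tpu":
                resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
                tf.config.experimental_connect_to_cluster(resolver)
                tf.tpu.experimental.initialize_tpu_system(resolver)
                strategy = tf.distribute.TPUStrategy(resolver)
            else:
                raise ValueError(f"Unsupported device type: {device_type}")
            with strategy.scope():
                yield
        else:
            raise ValueError(f"Unsupported backend: {backend}")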

Raises:

  • ValueError: If an unsupported device type or backend is provided.

Examples:

    # JAX backend
    distribute_scope = get_distribution_scope("gpu")
    with distribute_scope():
        # Your code here
        # e.g., build and train a model
 
    # TensorFlow backend
    distribute_scope = get_distribution_scope("tpu")
    with distribute_scope():
        # Your code here
        # e.g., build and train a model