training_utils
This module provides utilities for distributed training with Keras and TensorFlow.
Functions:
- get_distribution_scope(device_type: str) -> ContextManager: Returns a context manager for executing code in a distributed environment.
Usage:
To use the get_distribution_scope function, first set the KERAS_BACKEND environment variable to either 'jax' or 'tensorflow'. Then call the function with the desired device type ('cpu', 'gpu', or 'tpu').
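For example (the import path below is an assumption; adjust it to wherever this module lives in your project):

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

from training_utils import get_distribution_scope  # assumed import path

distribute_scope = get_distribution_scope("gpu")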
This function supports both the JAX and TensorFlow backends. With JAX it supports CPU, GPU, and TPU devices; with TensorFlow it supports the same device types using the appropriate distribution strategy for each.
The context manager returned by this function prints the total number of available devices and the total time taken by the code executed within it.
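For illustration only, a context manager with this reporting behaviour could look roughly like the sketch below; reporting_scope is a hypothetical name, and the real get_distribution_scope also configures the backend-specific distribution before yielding:

import time
from contextlib import contextmanager

import keras

@contextmanager
def reporting_scope(device_type):
    # Illustrative sketch only, not the module's actual implementation.
    # Count the available devices of the requested type and report them.
    devices = keras.distribution.list_devices(device_type)
    print(f"Total devices: {len(devices)}")
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report the wall-clock time spent inside the context.
        print(f"Total time: {time.perf_counter() - start:.2f}s")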
Args:
- device_type (str): The type of device to use for distributed training. Can be 'cpu', 'gpu', or 'tpu'.
Returns:
- A context manager object that can be used to execute code in a distributed environment.
Notes:
- For the JAX backend, this function uses the keras.distribution module to create a DataParallel distribution for GPU devices (a sketch appears after these notes). For more information, see the Keras guide on distributed training with JAX: https://keras.io/guides/distribution/
- For the TensorFlow backend, this function uses the appropriate distribution strategy based on the device type:
  - For CPU and GPU, it uses tf.distribute.MirroredStrategy
  - For TPU, it uses tf.distribute.TPUStrategy
  For more information on distributed training with TensorFlow, see the TensorFlow guide: https://www.tensorflow.org/guide/distributed_training
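For reference, the underlying APIs can also be exercised directly. A sketch of creating a DataParallel distribution with the Keras distribution API (not necessarily this module's exact code):

import keras

# Requires KERAS_BACKEND="jax"; the keras.distribution API currently targets the JAX backend.
devices = keras.distribution.list_devices("gpu")
distribution = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(distribution)

And a sketch of the corresponding TensorFlow strategies (standard TensorFlow API; the TPU lines are commented out because they only run on a TPU host):

import tensorflow as tf

# CPU / GPU: replicate the model across all visible local devices.
strategy = tf.distribute.MirroredStrategy()

# TPU: connect to the TPU system first, then build the strategy.
# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
# tf.config.experimental_connect_to_cluster(resolver)
# tf.tpu.experimental.initialize_tpu_system(resolver)
# strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    ...  # build and compile the model here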
Raises:
- ValueError: If an unsupported device type or backend is provided.
Examples:
# JAX backend (KERAS_BACKEND='jax')
distribute_scope = get_distribution_scope("gpu")
with distribute_scope():
    ...  # your code here, e.g. build and train a model

# TensorFlow backend (KERAS_BACKEND='tensorflow')
distribute_scope = get_distribution_scope("tpu")
with distribute_scope():
    ...  # your code here, e.g. build and train a model