.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto/integrations/kubernetes/kftensorflow/tf_mnist.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_integrations_kubernetes_kftensorflow_tf_mnist.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_integrations_kubernetes_kftensorflow_tf_mnist.py:

Distributed TensorFlow Training
-------------------------------

When you need to scale up model training using TensorFlow, you can use
:py:class:`~tensorflow:tf.distribute.Strategy` to distribute your training across multiple devices.
There are various strategies available under this API, and you can use any of them. In this example,
we use :py:class:`~tensorflow:tf.distribute.MirroredStrategy` to train an MNIST model with a
convolutional network.

:py:class:`~tensorflow:tf.distribute.MirroredStrategy` supports synchronous distributed training on
multiple GPUs on one machine. To learn more, refer to the
`Distributed training with TensorFlow <https://www.tensorflow.org/guide/distributed_training>`__ guide
in the TensorFlow documentation.

Let's get started with an example!

.. GENERATED FROM PYTHON SOURCE LINES 15-16

First, we load the libraries.

.. GENERATED FROM PYTHON SOURCE LINES 16-27

.. code-block:: default

    import os
    from dataclasses import dataclass
    from typing import NamedTuple, Tuple

    import tensorflow as tf
    import tensorflow_datasets as tfds
    from dataclasses_json import dataclass_json
    from flytekit import Resources, task, workflow
    from flytekit.types.directory import FlyteDirectory
    from flytekitplugins.kftensorflow import TfJob

.. GENERATED FROM PYTHON SOURCE LINES 28-29

We define ``MODEL_FILE_PATH`` to indicate where to store the model file.

.. GENERATED FROM PYTHON SOURCE LINES 29-31

.. code-block:: default

    MODEL_FILE_PATH = "saved_model/"

.. GENERATED FROM PYTHON SOURCE LINES 32-33

We initialize a data class to store the hyperparameters.

.. GENERATED FROM PYTHON SOURCE LINES 33-43

.. code-block:: default

    @dataclass_json
    @dataclass
    class Hyperparameters(object):
        batch_size_per_replica: int = 64
        buffer_size: int = 10000
        epochs: int = 10

.. GENERATED FROM PYTHON SOURCE LINES 44-48

Loading the Data
================

We use the `MNIST <https://www.tensorflow.org/datasets/catalog/mnist>`__ dataset to train our model.

.. GENERATED FROM PYTHON SOURCE LINES 48-75

.. code-block:: default

    def load_data(
        hyperparameters: Hyperparameters,
    ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.distribute.Strategy]:
        datasets, _ = tfds.load(name="mnist", with_info=True, as_supervised=True)
        mnist_train, mnist_test = datasets["train"], datasets["test"]

        strategy = tf.distribute.MirroredStrategy()
        print("Number of devices: {}".format(strategy.num_replicas_in_sync))

        # scale the per-replica batch size by the number of replicas
        # (strategy.num_replicas_in_sync) to utilize the extra compute power
        BATCH_SIZE = hyperparameters.batch_size_per_replica * strategy.num_replicas_in_sync

        def scale(image, label):
            image = tf.cast(image, tf.float32)
            image /= 255
            return image, label

        # fetch the train and evaluation datasets
        train_dataset = (
            mnist_train.map(scale).shuffle(hyperparameters.buffer_size).batch(BATCH_SIZE)
        )
        eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

        return train_dataset, eval_dataset, strategy
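
To make the batch-size arithmetic concrete, here is a minimal standalone sketch, not part of
the generated example; the ``per_replica_batch_size`` value is illustrative. On a machine with
a single device, ``num_replicas_in_sync`` is 1 and the global batch size equals the
per-replica one.

.. code-block:: python

    import tensorflow as tf

    # MirroredStrategy discovers the local devices; on a CPU-only machine it
    # reports a single replica.
    strategy = tf.distribute.MirroredStrategy()

    per_replica_batch_size = 64  # illustrative value
    global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
    print(
        "replicas: {}, global batch size: {}".format(
            strategy.num_replicas_in_sync, global_batch_size
        )
    )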

.. GENERATED FROM PYTHON SOURCE LINES 76-80

Compiling the Model
===================

We create and compile a model in the context of `Strategy.scope <https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#scope>`__.

.. GENERATED FROM PYTHON SOURCE LINES 80-103

.. code-block:: default

    def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
        with strategy.scope():
            model = tf.keras.Sequential(
                [
                    tf.keras.layers.Conv2D(
                        32, 3, activation="relu", input_shape=(28, 28, 1)
                    ),
                    tf.keras.layers.MaxPooling2D(),
                    tf.keras.layers.Flatten(),
                    tf.keras.layers.Dense(64, activation="relu"),
                    tf.keras.layers.Dense(10),
                ]
            )
            model.compile(
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"],
            )
        return model

.. GENERATED FROM PYTHON SOURCE LINES 104-108

Training
========

We define a function for decaying the learning rate.

.. GENERATED FROM PYTHON SOURCE LINES 108-117

.. code-block:: default

    def decay(epoch: int) -> float:
        if epoch < 3:
            return 1e-3
        elif epoch < 7:
            return 1e-4
        else:
            return 1e-5

.. GENERATED FROM PYTHON SOURCE LINES 118-123

Next, we define ``train_model`` to train the model with three callbacks:

* :py:class:`~tensorflow:tf.keras.callbacks.TensorBoard` to log the training metrics
* :py:class:`~tensorflow:tf.keras.callbacks.ModelCheckpoint` to save the model after every epoch
* :py:class:`~tensorflow:tf.keras.callbacks.LearningRateScheduler` to decay the learning rate

.. GENERATED FROM PYTHON SOURCE LINES 123-162

.. code-block:: default

    def train_model(
        model: tf.keras.Model,
        train_dataset: tf.data.Dataset,
        hyperparameters: Hyperparameters,
    ) -> Tuple[tf.keras.Model, str]:
        # define the checkpoint directory to store the checkpoints
        checkpoint_dir = "./training_checkpoints"

        # define the name of the checkpoint files
        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

        # define a callback for printing the learning rate at the end of each epoch
        class PrintLR(tf.keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                print(
                    "\nLearning rate for epoch {} is {}".format(
                        epoch + 1, model.optimizer.lr.numpy()
                    )
                )

        # put all the callbacks together
        callbacks = [
            tf.keras.callbacks.TensorBoard(log_dir="./logs"),
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_prefix, save_weights_only=True
            ),
            tf.keras.callbacks.LearningRateScheduler(decay),
            PrintLR(),
        ]

        # train the model
        model.fit(train_dataset, epochs=hyperparameters.epochs, callbacks=callbacks)

        # save the model
        model.save(MODEL_FILE_PATH, save_format="tf")

        return model, checkpoint_dir

.. GENERATED FROM PYTHON SOURCE LINES 163-167

Evaluation
==========

We define ``test_model`` to evaluate the loss and accuracy on the test dataset.

.. GENERATED FROM PYTHON SOURCE LINES 167-177

.. code-block:: default

    def test_model(
        model: tf.keras.Model, checkpoint_dir: str, eval_dataset: tf.data.Dataset
    ) -> Tuple[float, float]:
        model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

        eval_loss, eval_acc = model.evaluate(eval_dataset)
        return eval_loss, eval_acc
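
``test_model`` restores the latest checkpoint weights before evaluating. The SavedModel written
by ``train_model`` can also be reloaded and evaluated directly. Here is a minimal sketch,
assuming the model was exported to ``saved_model/``; ``evaluate_saved_model`` is a hypothetical
helper, not part of the example:

.. code-block:: python

    import tensorflow as tf

    # hypothetical helper: reload the exported SavedModel and evaluate it on an
    # already-batched dataset
    def evaluate_saved_model(model_dir: str, eval_dataset: tf.data.Dataset):
        # load_model restores the architecture along with the compiled loss and
        # metrics, so evaluate() can be called without recompiling
        restored = tf.keras.models.load_model(model_dir)  # e.g. "saved_model/"
        return restored.evaluate(eval_dataset)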

.. GENERATED FROM PYTHON SOURCE LINES 178-193

Defining an MNIST TensorFlow Task
=================================

We initialize the compute requirements and the task output signature. Next, we define
``mnist_tensorflow_job`` to kick off the training and evaluation process. The task is configured
with ``TfJob``, which accepts the following values:

* ``num_workers``: the number of worker replicas to spawn in the cluster for this job
* ``num_ps_replicas``: the number of parameter server replicas to use
* ``num_chief_replicas``: the number of chief replicas to use

``MirroredStrategy`` uses an all-reduce algorithm to communicate variable updates across the
devices, so ``num_ps_replicas`` is not useful in our example.

.. note::
    If you'd like to understand the various TensorFlow strategies available for distributed
    training, refer to the `Types of strategies
    <https://www.tensorflow.org/guide/distributed_training#types_of_strategies>`__ section in the
    TensorFlow documentation.

.. GENERATED FROM PYTHON SOURCE LINES 193-229

.. code-block:: default

    training_outputs = NamedTuple(
        "TrainingOutputs", accuracy=float, loss=float, model_state=FlyteDirectory
    )

    # request small resources when the SANDBOX environment variable is set
    if os.getenv("SANDBOX"):
        resources = Resources(
            gpu="0", mem="1000Mi", storage="500Mi", ephemeral_storage="500Mi"
        )
    else:
        resources = Resources(
            gpu="2", mem="10Gi", storage="10Gi", ephemeral_storage="500Mi"
        )


    @task(
        task_config=TfJob(num_workers=2, num_ps_replicas=1, num_chief_replicas=1),
        retries=2,
        cache=True,
        cache_version="1.0",
        requests=resources,
        limits=resources,
    )
    def mnist_tensorflow_job(hyperparameters: Hyperparameters) -> training_outputs:
        train_dataset, eval_dataset, strategy = load_data(hyperparameters=hyperparameters)
        model = get_compiled_model(strategy=strategy)
        model, checkpoint_dir = train_model(
            model=model, train_dataset=train_dataset, hyperparameters=hyperparameters
        )
        eval_loss, eval_accuracy = test_model(
            model=model, checkpoint_dir=checkpoint_dir, eval_dataset=eval_dataset
        )
        return training_outputs(
            accuracy=eval_accuracy, loss=eval_loss, model_state=MODEL_FILE_PATH
        )

.. GENERATED FROM PYTHON SOURCE LINES 230-234

Workflow
========

Finally, we define a workflow to call the ``mnist_tensorflow_job`` task.

.. GENERATED FROM PYTHON SOURCE LINES 234-241

.. code-block:: default

    @workflow
    def mnist_tensorflow_workflow(
        hyperparameters: Hyperparameters = Hyperparameters(),
    ) -> training_outputs:
        return mnist_tensorflow_job(hyperparameters=hyperparameters)

.. GENERATED FROM PYTHON SOURCE LINES 242-243

We can also run the code locally.

.. GENERATED FROM PYTHON SOURCE LINES 243-245

.. code-block:: default

    if __name__ == "__main__":
        print(mnist_tensorflow_workflow())

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)

.. _sphx_glr_download_auto_integrations_kubernetes_kftensorflow_tf_mnist.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tf_mnist.py <tf_mnist.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tf_mnist.ipynb <tf_mnist.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
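
As a closing usage sketch, not part of the generated example, the workflow's default
hyperparameters can be overridden for a quick local smoke test. The ``tf_mnist`` module name
and the specific field overrides below are assumptions for illustration:

.. code-block:: python

    # hypothetical smoke test, assuming the definitions above live in tf_mnist.py
    from tf_mnist import Hyperparameters, mnist_tensorflow_workflow

    if __name__ == "__main__":
        # one short epoch with a smaller per-replica batch keeps the run quick
        outputs = mnist_tensorflow_workflow(
            hyperparameters=Hyperparameters(epochs=1, batch_size_per_replica=32)
        )
        print("accuracy:", outputs.accuracy, "loss:", outputs.loss)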