.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto/integrations/flytekit_plugins/onnx_examples/scikitlearn_onnx.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_integrations_flytekit_plugins_onnx_examples_scikitlearn_onnx.py:

ScikitLearn Example
-------------------

This example shows how to convert a scikit-learn model to an ONNX model.

First, import the necessary libraries.

.. GENERATED FROM PYTHON SOURCE LINES 9-23

.. code-block:: default

    from typing import List, NamedTuple

    import numpy
    import onnxruntime as rt
    import pandas as pd
    from flytekit import task, workflow
    from flytekit.types.file import ONNXFile
    from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig
    from skl2onnx.common.data_types import FloatTensorType
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    from typing_extensions import Annotated

.. GENERATED FROM PYTHON SOURCE LINES 24-27

Define a ``NamedTuple`` to hold the output schema.
Note the annotation on the ``model`` field: it tells Flytekit to convert this output
to an ONNX model using the given metadata.

.. GENERATED FROM PYTHON SOURCE LINES 27-45

.. code-block:: default

    TrainOutput = NamedTuple(
        "TrainOutput",
        [
            (
                "model",
                Annotated[
                    ScikitLearn2ONNX,
                    ScikitLearn2ONNXConfig(
                        # One float input with a variable batch size and
                        # four features per row (the iris measurements).
                        initial_types=[("float_input", FloatTensorType([None, 4]))],
                        target_opset=12,
                    ),
                ],
            ),
            ("test", pd.DataFrame),
        ],
    )

.. GENERATED FROM PYTHON SOURCE LINES 46-47

Define a ``train`` task that trains a scikit-learn model and returns the model and the test data.

.. GENERATED FROM PYTHON SOURCE LINES 47-58

.. code-block:: default

    @task
    def train() -> TrainOutput:
        iris = load_iris(as_frame=True)
        X, y = iris.data, iris.target
        # The test labels are not needed here, so they are discarded.
        X_train, X_test, y_train, _ = train_test_split(X, y)
        model = RandomForestClassifier()
        model.fit(X_train, y_train)

        return TrainOutput(test=X_test, model=ScikitLearn2ONNX(model))

.. GENERATED FROM PYTHON SOURCE LINES 59-60

Define a ``predict`` task that uses the model to predict the labels for the test data.

.. GENERATED FROM PYTHON SOURCE LINES 60-74

.. code-block:: default

    @task
    def predict(
        model: ONNXFile,
        X_test: pd.DataFrame,
    ) -> List[int]:
        # Download the ONNX file and load it into an inference session.
        sess = rt.InferenceSession(model.download())
        input_name = sess.get_inputs()[0].name
        label_name = sess.get_outputs()[0].name
        # The model was exported with a float tensor input, so cast to float32.
        pred_onx = sess.run(
            [label_name], {input_name: X_test.to_numpy(dtype=numpy.float32)}
        )[0]
        return pred_onx.tolist()

.. GENERATED FROM PYTHON SOURCE LINES 75-76

Lastly, define a workflow to run the above tasks.

.. GENERATED FROM PYTHON SOURCE LINES 76-82

.. code-block:: default

    @workflow
    def wf() -> List[int]:
        train_output = train()
        return predict(model=train_output.model, X_test=train_output.test)

.. GENERATED FROM PYTHON SOURCE LINES 83-84

Run the workflow locally.

.. GENERATED FROM PYTHON SOURCE LINES 84-86

.. code-block:: default

    if __name__ == "__main__":
        print(f"Predictions: {wf()}")
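The ``Annotated`` field in ``TrainOutput`` attaches a configuration object to a type hint. The sketch below uses only the standard library to illustrate how such metadata can be read back from a ``NamedTuple`` at runtime. It is not Flytekit's implementation: ``ExportConfig``, ``ModelHandle``, ``Output``, and ``find_metadata`` are hypothetical stand-ins introduced for illustration, and it assumes Python 3.9 or later, where ``Annotated`` is available from ``typing`` rather than ``typing_extensions``.

```python
from dataclasses import dataclass
from typing import Annotated, Any, NamedTuple, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class ExportConfig:
    """Hypothetical stand-in for a converter config such as ScikitLearn2ONNXConfig."""

    target_opset: int


class ModelHandle:
    """Hypothetical stand-in for a wrapper type such as ScikitLearn2ONNX."""

    def __init__(self, model: Any) -> None:
        self.model = model


class Output(NamedTuple):
    # The second argument to Annotated is arbitrary metadata carried by the hint.
    model: Annotated[ModelHandle, ExportConfig(target_opset=12)]
    rows: int


def find_metadata(tuple_type: type) -> dict:
    """Map each Annotated field name to (base type, tuple of metadata objects)."""
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    hints = get_type_hints(tuple_type, include_extras=True)
    found = {}
    for name, hint in hints.items():
        if get_origin(hint) is Annotated:
            base, *extras = get_args(hint)
            found[name] = (base, tuple(extras))
    return found


metadata = find_metadata(Output)
base_type, extras = metadata["model"]
print(base_type.__name__, extras[0].target_opset)
```

Only the ``model`` field is reported, because ``rows`` carries a plain type hint with no metadata. A framework can use this kind of lookup to decide, per field, whether a returned value needs extra processing and with which settings.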