Image classification model training
Let's build a deep learning image classification model using Keras / TensorFlow, and convert it to the ONNX format.
In [1]:
import json
from pathlib import Path
import keras
import numpy as np
import onnx
import onnxruntime as ort
import requests
import tf2onnx
Load a MobileNetV2 image classification model pretrained on ImageNet:
In [2]:
model = keras.applications.mobilenet_v2.MobileNetV2()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5
14536120/14536120 [==============================] - 0s 0us/step
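As a quick sanity check (an illustrative addition, not part of the original run), the default MobileNetV2 expects 224×224 RGB inputs and produces one score per ImageNet class:
print(model.input_shape)   # (None, 224, 224, 3)
print(model.output_shape)  # (None, 1000)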
Load an example image and preprocess it into a tensor of shape (1, 224, 224, 3):
In [3]:
img_path = "cat.jpg"
img = keras.utils.load_img(img_path, target_size=(224, 224))
x = keras.utils.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = keras.applications.mobilenet_v2.preprocess_input(x)
img
Out[3]:
(The loaded 224×224 cat image is displayed here.)
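preprocess_input rescales the pixel values from [0, 255] to [-1, 1], which is the range MobileNetV2 expects; a quick check of the preprocessed tensor (an illustrative addition) would be:
print(x.shape)           # (1, 224, 224, 3) -- batch dimension added by np.expand_dims
print(x.min(), x.max())  # approximately -1.0 and 1.0 after preprocessing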
Try the model on the sample image and decode the top 3 predicted classes:
In [4]:
preds = model.predict(x)
keras.applications.mobilenet_v2.decode_predictions(preds, top=3)
1/1 [==============================] - 1s 849ms/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
35363/35363 [==============================] - 0s 0us/step
Out[4]:
[[('n02123045', 'tabby', 0.56931037), ('n02124075', 'Egyptian_cat', 0.26797795), ('n02123159', 'tiger_cat', 0.116148554)]]
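preds is a (1, 1000) array of softmax probabilities over the ImageNet classes; decode_predictions simply picks the largest entries and maps them to labels. A few illustrative checks (not part of the original run):
print(preds.shape)       # (1, 1000)
print(preds.sum())       # ~1.0, since the classifier head ends in a softmax
print(np.argmax(preds))  # index of the top class ("tabby" above)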
Download the ImageNet class index JSON file and save it locally:
In [5]:
rv = requests.get(
"https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
)
imagenet_class_index = rv.json()
(Path() / "imagenet_class_index.json").write_text(json.dumps(imagenet_class_index))
Out[5]:
35363
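The file maps each class index (as a string) to a [WordNet ID, label] pair, which is all decode_predictions needs. For reference, a manual top-3 lookup equivalent to the earlier decode_predictions call (a sketch using the objects already in scope):
top3 = np.argsort(preds[0])[::-1][:3]
[(imagenet_class_index[str(i)][0], imagenet_class_index[str(i)][1], float(preds[0][i])) for i in top3]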
Export the model to ONNX format using tf2onnx:
In [6]:
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx_model_path = Path() / "model.onnx"
onnx.save(onnx_model, onnx_model_path)
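Before running inference, it can help to validate the exported graph and confirm its input name, since that name must match the key passed to session.run below (an optional check, not in the original notebook):
onnx.checker.check_model(onnx_model)                 # raises if the graph is malformed
print([inp.name for inp in onnx_model.graph.input])  # expected to contain something like 'input_1'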
Load the ONNX model, run inference on the sample image, and look up the predicted class name:
In [7]:
session = ort.InferenceSession(onnx_model_path, providers=ort.get_available_providers())
output = session.run(None, {"input_1": x})[0]
imagenet_class_index[str(np.argmax(output[0]))][1]
Out[7]:
'tabby'
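As a final validation (a minimal sketch using the arrays computed above), the ONNX Runtime output should match the original Keras predictions within floating-point tolerance:
# Compare the Keras and ONNX Runtime probabilities element-wise
np.testing.assert_allclose(preds, output, rtol=1e-3, atol=1e-5)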