From PyTorch to TensorFlow Lite: Model Training and Inference
1. Overview of the Approach
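Development stage:
- Flexibility and ease of use: PyTorch's dynamic (eager) computation graph makes models quick to build and debug, which suits research and experimentation.
- Rapid prototyping: new ideas can be iterated on and validated quickly, which simplifies model design.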
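Because PyTorch runs eagerly, you can inspect intermediate tensors with plain Python while designing a model. The sketch below rebuilds the convolution/pooling stack of the `CNNModel` from section 2.1 as an `nn.Sequential` named `features`. That stand-in is written only for this illustration and assumes nothing beyond an installed PyTorch. It prints the shape after each layer, which confirms the `64 * 7 * 7` input size of the first fully connected layer.

```python
import torch
import torch.nn as nn

# Stand-in for the conv/pool part of CNNModel (section 2.1), written for this illustration
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

x = torch.zeros(1, 1, 28, 28)  # one MNIST-sized image, NCHW
with torch.no_grad():
    for layer in features:
        x = layer(x)  # eager execution: every intermediate tensor is a normal Python object
        print(f"{type(layer).__name__:<10} -> {tuple(x.shape)}")

flat_dim = x.flatten(1).shape[1]
print("flattened size:", flat_dim)  # 3136 == 64 * 7 * 7
```

Each `MaxPool2d(2)` halves the spatial size (28 to 14 to 7), so the flattened feature vector has 64 × 7 × 7 = 3136 elements.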
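Conversion stage:
- Model conversion: the trained PyTorch model is exported to ONNX with the TorchScript-based `torch.onnx.export`. The ONNX model is then converted to a TensorFlow (Keras) model and finally to a TFLite model.
- Mobile optimization: TensorFlow Lite is designed for mobile and embedded devices. It supports quantization and other optimizations that improve inference speed and efficiency.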
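Here is a minimal sketch of the quantization support mentioned above. It uses a small hypothetical `tf.Module` holding a single 784×10 weight matrix, not the article's CNN. The sketch converts that module to TFLite twice: once as plain float32, and once with post-training dynamic-range quantization (`tf.lite.Optimize.DEFAULT`). It then compares the flatbuffer sizes. For the `cnn_mnist.h5` file that onnx2tflite writes in section 2.2, `tf.lite.TFLiteConverter.from_keras_model(tf.keras.models.load_model('cnn_mnist.h5'))` would be the more natural entry point. That assumes TensorFlow 2.x with tf.keras, as pinned there.

```python
import numpy as np
import tensorflow as tf


class TinyDense(tf.Module):
    """Hypothetical stand-in model: a single 784x10 weight matrix."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([784, 10], seed=0))

    @tf.function(input_signature=[tf.TensorSpec([1, 784], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)


def to_tflite(module, quantize):
    """Convert the module's traced function to a TFLite flatbuffer (bytes)."""
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [module.__call__.get_concrete_function()], module
    )
    if quantize:
        # Post-training dynamic-range quantization: weights are stored as int8
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    return converter.convert()


model = TinyDense()
float_bytes = to_tflite(model, quantize=False)
quant_bytes = to_tflite(model, quantize=True)
print(f"float32: {len(float_bytes)} bytes, quantized: {len(quant_bytes)} bytes")

# The quantized flatbuffer runs through the same Interpreter API used in section 2.3
interpreter = tf.lite.Interpreter(model_content=quant_bytes)
interpreter.allocate_tensors()
x = np.ones((1, 784), dtype=np.float32)
interpreter.set_tensor(interpreter.get_input_details()[0]["index"], x)
interpreter.invoke()
y = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
print("output shape:", y.shape)
```

Dynamic-range quantization needs no calibration data, which makes it the simplest option to try first. Full-integer quantization additionally requires a representative dataset.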
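Deployment stage:
- Efficient execution: the TFLite model is integrated into Android, iOS, or embedded systems, where it uses less memory and fewer compute resources.
- Optimization techniques: quantization and similar techniques speed up inference, extend battery life, and improve the user experience.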
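TFLite represents int8 tensors with an affine mapping, real_value = scale * (int8_value - zero_point). The NumPy sketch below applies a simplified per-tensor version of that mapping to a hypothetical 128×10 float32 weight matrix, the same number of weights as `fc2` in the CNN of section 2.1. It shows the 4x storage saving and the rounding error that saving costs. This is an illustration only. TFLite's converter chooses scales itself and, for example, quantizes convolution weights per channel.

```python
import numpy as np


def quantize_int8(x):
    """Per-tensor asymmetric int8 quantization from the min/max range (simplified)."""
    x_min = min(float(x.min()), 0.0)  # the range must include 0 so that 0.0 is exact
    x_max = max(float(x.max()), 0.0)
    scale = (x_max - x_min) / 255.0 or 1.0  # 256 int8 levels; guard all-zero input
    zero_point = int(round(-128 - x_min / scale))  # int8 value that represents 0.0
    q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    # real = scale * (q - zero_point)
    return scale * (q.astype(np.float32) - zero_point)


rng = np.random.default_rng(0)
w = rng.standard_normal((128, 10)).astype(np.float32)  # hypothetical weight matrix
q, scale, zp = quantize_int8(w)
w_hat = dequantize(q, scale, zp)
max_err = float(np.abs(w - w_hat).max())

print(w.nbytes, q.nbytes)  # 5120 1280: int8 storage is 4x smaller than float32
print(f"scale={scale:.5f}, zero_point={zp}, max abs error={max_err:.5f}")
```

The reconstruction error stays around scale / 2. That is why quantized models usually lose little accuracy while cutting memory traffic, which is where much of the speed and battery benefit comes from.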
2. Example 1: Converting a CNN Model
Note: the Python version used is 3.10.
2.1 Converting .pth to ONNX and checking consistency
```python
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn


class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)  # 28x28 -> 14x14
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)  # 14x14 -> 7x7
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x


device = torch.device("cpu")
print(f"Using device: {device}")
model = CNNModel().to(device)
# Note: the weights are saved untrained here; in practice, train the model before saving
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")

model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True))  # load the saved weights
model.eval()  # switch to evaluation mode

# Create a sample input with the shape of one MNIST image (N, C, H, W)
dummy_input = torch.randn(1, 1, 28, 28).to(device)

# Run inference with PyTorch
with torch.no_grad():
    pytorch_output = model(dummy_input)

# Export the model to ONNX
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")

# Load and validate the ONNX model, then create an ONNX Runtime session
onnx_model = onnx.load('cnn_mnist.onnx')
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession('cnn_mnist.onnx', providers=["CPUExecutionProvider"])

# Prepare the input: convert the PyTorch tensor to a float32 NumPy array
onnx_input = dummy_input.numpy().astype(np.float32)

# Run inference with ONNX Runtime
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})

# Compare the outputs
pytorch_output_np = pytorch_output.numpy()  # convert the PyTorch output to NumPy
onnx_output_np = onnx_output[0]  # run() returns a list of outputs; take the first
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):
    print("The PyTorch and ONNX outputs are consistent.")
else:
    print("The PyTorch and ONNX outputs are not consistent.")

# Print the outputs
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
```

Output:

```text
Using device: cpu
Model saved as cnn_mnist.pth
Model exported to cnn_mnist.onnx
The PyTorch and ONNX outputs are consistent.
PyTorch output: [[ 0.09420984 -0.22566104 -0.09049285 -0.12713039  0.16126882 -0.03506599
  -0.06087608 -0.08624509  0.14606059 -0.10252683]]
ONNX output: [[ 0.09420987 -0.22566104 -0.0904929  -0.12713037  0.16126883 -0.03506595
  -0.06087613 -0.08624507  0.14606053 -0.10252687]]
```
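`np.allclose` only answers yes or no. When a conversion fails the check, it helps to see how large the error actually is. A small helper along these lines (the name `compare_outputs` is hypothetical) reports the maximum absolute and relative errors alongside the same `allclose` test. The example feeds it the values printed above. NumPy rounded those values for display, so the computed errors are only approximate.

```python
import numpy as np


def compare_outputs(ref, test, atol=1e-5, rtol=1e-5):
    """Summarize how far `test` is from `ref` (hypothetical helper)."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    abs_err = np.abs(ref - test)
    rel_err = abs_err / np.maximum(np.abs(ref), 1e-12)  # avoid division by zero
    return {
        "max_abs_err": float(abs_err.max()),
        "max_rel_err": float(rel_err.max()),
        "allclose": bool(np.allclose(test, ref, atol=atol, rtol=rtol)),
    }


# Values copied from the printed PyTorch and ONNX outputs above
pt = [0.09420984, -0.22566104, -0.09049285, -0.12713039, 0.16126882,
      -0.03506599, -0.06087608, -0.08624509, 0.14606059, -0.10252683]
ox = [0.09420987, -0.22566104, -0.0904929, -0.12713037, 0.16126883,
      -0.03506595, -0.06087613, -0.08624507, 0.14606053, -0.10252687]
report = compare_outputs(pt, ox)
print(report)
```

In the script above, the same helper would be called as `compare_outputs(pytorch_output_np, onnx_output_np)`.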
2.2 Converting the ONNX model to TFLite
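This step uses the onnx2tflite project. Clone it, create a Python 3.10 environment with the pinned dependency versions below, and run the converter on the ONNX file from section 2.1: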
```shell
git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda create -n onnx2tf python=3.10
conda activate onnx2tf
conda install tensorflow==2.11.0
pip install opencv_python~=4.8.0
pip install onnx==1.17.0
pip install onnxruntime==1.19.2
pip install onnxsim==0.4.36
# Run from inside the onnx2tflite directory; cnn_mnist.onnx is in the parent directory
python -m onnx2tflite --weights ../cnn_mnist.onnx
```
Output:

```text
(onnx2tf) ➜ onnx2tflite python -m onnx2tflite --weights ../cnn_mnist.onnx
Checking 0/1...
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /var/folders/p1/0hjnc12n4jz2j_flm0lnr3hr0000gn/T/tmpgd0n9r_5/assets
INFO:tensorflow:Assets written to: /var/folders/p1/0hjnc12n4jz2j_flm0lnr3hr0000gn/T/tmpgd0n9r_5/assets
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:converter running::keras model saved in ../cnn_mnist.h5
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
1/1 [==============================] - 0s 28ms/step
INFO: Initialized TensorFlow Lite runtime.
INFO: Applying 1 TensorFlow Lite delegate(s) lazily.
INFO:converter running::h5 model elements' max error is 3.7253E-08, pass, h5 saved in ../cnn_mnist.h5
INFO:converter running::tflite model elements' max error is 4.4703E-08, pass, tflite saved in ../cnn_mnist.tflite
```

The converter writes both a Keras model (`../cnn_mnist.h5`) and the TFLite model (`../cnn_mnist.tflite`). Its built-in check reports a maximum element-wise error of 3.7e-8 for the h5 model and 4.5e-8 for the TFLite model, and marks both as passed. The warnings about untraced functions and missing training configuration do not affect inference.
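The converted model also changes the tensor layout. Section 2.3 prints a TFLite input shape of `[1 28 28 1]`, which is NHWC, while the ONNX model takes NCHW input `(1, 1, 28, 28)`. Any input fed to both models therefore needs a transpose. The two helper names below are hypothetical. The sketch uses three channels so that the axis order is visible. It also returns contiguous copies, since a bare `transpose` only produces a strided view.

```python
import numpy as np


def nchw_to_nhwc(x):
    """(N, C, H, W) -> (N, H, W, C), returned as a C-contiguous copy."""
    return np.ascontiguousarray(x.transpose(0, 2, 3, 1))


def nhwc_to_nchw(x):
    """(N, H, W, C) -> (N, C, H, W), returned as a C-contiguous copy."""
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))


rng = np.random.default_rng(0)
x_nchw = rng.random((1, 3, 28, 28), dtype=np.float32)  # PyTorch/ONNX layout
x_nhwc = nchw_to_nhwc(x_nchw)                           # TFLite layout after conversion
print(x_nchw.shape, "->", x_nhwc.shape)  # (1, 3, 28, 28) -> (1, 28, 28, 3)

# Element [n, c, h, w] in NCHW is element [n, h, w, c] in NHWC
assert x_nchw[0, 2, 5, 7] == x_nhwc[0, 5, 7, 2]
assert np.array_equal(nhwc_to_nchw(x_nhwc), x_nchw)
```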
2.3 Verifying consistency between the ONNX and TFLite models
```python
import numpy as np
import onnxruntime as ort
import tensorflow as tf

# 1. Load the ONNX model
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

# 2. Load the TFLite model
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()

# 3. Prepare the input data
# Random data shaped like one MNIST image in NHWC layout (1, 28, 28, 1),
# which is what the converted TFLite model expects
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32)
# The ONNX model expects NCHW (1, 1, 28, 28); make the transposed array contiguous
input_data_onnx = np.ascontiguousarray(input_data.transpose(0, 3, 1, 2))

# 4. Run both models on the same input
# ONNX model inference
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)

# TFLite model inference
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()
# Check the TFLite input shape
print("TFLite Input Shape:", tflite_input_details[0]['shape'])
# Set the TFLite input; its shape must match the model's input shape
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)

# 5. Compare the outputs
# Element-wise absolute difference
onnx_difference = np.abs(onnx_output - tflite_output)
print("Difference (ONNX vs TFLite):", onnx_difference)
# Check consistency against a fixed threshold
if np.all(onnx_difference < 1e-5):
    print("The outputs are consistent between ONNX and TFLite models.")
else:
    print("The outputs are not consistent between ONNX and TFLite models.")
```
Output:

```text
ONNX Output: [[ -3.7372704  -6.5073314  -1.1807165  -2.4232314 -10.638929    2.2660115
   -4.5868526  -2.7494073  -0.5609715  -6.331989 ]]
TFLite Input Shape: [ 1 28 28  1]
TFLite Output: [[ -3.7372704   -6.5073323   -1.180716    -2.4232314  -10.638928
    2.2660117   -4.5868545   -2.7494078   -0.56097114  -6.331988  ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-07
  2.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.
```
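A single random input is a fairly weak consistency test. The sketch below repeats the comparison over several seeded random inputs and reports the worst difference. The function `max_diff_over_trials` and the toy runners are hypothetical. The toys model a Flatten + Dense layer written once for NCHW and once for NHWC, with the weight rows permuted the way a layout-changing converter has to permute them. The commented lines show how the ONNX session and TFLite interpreter from the script above could be plugged in instead.

```python
import numpy as np


def max_diff_over_trials(run_ref, run_test, input_shape_nhwc, n_trials=20, seed=0):
    """Worst max-abs difference over several random inputs.

    run_ref takes NCHW input (e.g. ONNX Runtime); run_test takes NHWC (e.g. TFLite).
    """
    rng = np.random.default_rng(seed)
    worst = 0.0
    for _ in range(n_trials):
        x_nhwc = rng.random(input_shape_nhwc, dtype=np.float32)
        x_nchw = np.ascontiguousarray(x_nhwc.transpose(0, 3, 1, 2))
        diff = np.abs(run_ref(x_nchw) - run_test(x_nhwc)).max()
        worst = max(worst, float(diff))
    return worst


# Toy stand-ins: Flatten + Dense in NCHW order, and the same layer in NHWC order
C, H, W = 3, 4, 4
w_nchw = np.random.default_rng(1).standard_normal((C * H * W, 10)).astype(np.float32)
# Row for (h, w, c) in NHWC order = row for (c, h, w) in NCHW order
w_nhwc = w_nchw.reshape(C, H, W, 10).transpose(1, 2, 0, 3).reshape(H * W * C, 10)


def toy_ref(x_nchw):
    return x_nchw.reshape(1, -1) @ w_nchw


def toy_test(x_nhwc):
    return x_nhwc.reshape(1, -1) @ w_nhwc


worst = max_diff_over_trials(toy_ref, toy_test, (1, H, W, C))
print(f"worst max-abs difference over 20 inputs: {worst:.2e}")

# With the objects from the script above, the runners could be:
#   run_ref = lambda x: onnx_session.run(None, {onnx_input_name: x})[0]
#   def run_test(x):
#       tflite_interpreter.set_tensor(tflite_input_details[0]['index'], x)
#       tflite_interpreter.invoke()
#       return tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
#   max_diff_over_trials(run_ref, run_test, (1, 28, 28, 1))
```

Fixing the seed makes a failing input reproducible. The differences that remain in the toy case come only from the different float32 summation order, much like the ~1e-6 differences printed above.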