一、模型训练
1. 为保证后续的模型转换可能会失败,选取了ModelZoo-PyTorch中支持的代码仓库dbnet来进行训练,该仓库地址为dbnet作者原始的代码仓库。
2.配置相关训练参数,训练一个dbnet可用模型,因使用场景在印刷体文字,不需训练多长时间便可以得到一个较好效果的pt格式模型文件。
二、模型转换
1. 转换onnx文件
python3 convert_onnx.py
2. 转换om文件
om文件与机器有关,转换时在昇腾910机器上进行,使用atc工具进行转换,转换脚本如下:
atc --model=db_ic15_resnet50_16.onnx --framework=5 --output=db_ic15_resnet50_16 --input_format=NCHW --input_shape="input:1,3,960,960" --log=error --soc_version=Ascend910B3
其中 soc_version 指定机器的型号
三、模型测试
昇腾提供了一个名为benchmark的python API,可以用来进行离线模型(.om模型)推理。
ais_bench推理工具的安装包括aclruntime包和ais_bench推理程序包的安装。安装方式有多种,可以使用whl包进行安装。安装命令如下:
pip install aclruntime-0.0.2-cp310-cp310-linux_aarch64.whl
pip install ais_bench-0.0.2-py3-none-any.whl
安装成功后,便可以引入ais_bench进行推理。
ais_bench的使用包括导入依赖包、加载模型、图像预处理、调用接口推理模型得到输出、图像后处理、释放模型占用的内存几个步骤。
以下给出dbnet的完整测试代码:
import argparse
from tqdm import tqdm
import numpy as np
import glob
import cv2
from ais_bench.infer.interface import InferSession
def resize_image(img):
input_w = 960
input_h = 960
h, w, c = img.shape
r_w = input_w / w
r_h = input_h / h
if r_h > r_w:
tw = input_w
th = int(r_w * h)
tx1 = tx2 = 0
ty1 = 0
ty2 = input_h - th
else:
tw = int(r_h * w)
th = input_h
tx1 = 0
tx2 = input_w - tw - tx1
ty1 = ty2 = 0
# Resize the image with long side while maintaining ratio
resized_img = cv2.resize(img, (tw, th))
resized_img = cv2.copyMakeBorder(
resized_img, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (0, 0, 0)
)
return resized_img, tw, th
def transfer_pic(origin_image):
# 图像预处理
dbnet_input_data, tw, th = resize_image(origin_image)
dbnet_input_data = dbnet_input_data.astype(np.float16)
dbnet_input_data -= np.array(
[122.67891434, 116.66876762, 104.00698793])
dbnet_input_data /= 255.0
dbnet_input_data = dbnet_input_data.transpose([2, 0, 1])
dbnet_input_data = dbnet_input_data[np.newaxis, :]
return dbnet_input_data, tw, th
def get_bin(pred):
pred = pred[0][0]
_bitmap = pred > 0.3
bitmap = _bitmap
height, width = bitmap.shape
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
bitmap = (bitmap * 255).astype(np.uint8)
bitmap = cv2.dilate(bitmap, kernel)
# outs = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
return bitmap
def main(data_path, npu_session):
files = glob.glob(data_path+'/*.png')+glob.glob(data_path+'/*.jpg')
for data in tqdm(files):
data = cv2.imread(data)
data, new_w,new_h = transfer_pic(data)
npu_result = npu_session.infer(data, "static")
bitmap = get_bin(npu_result[0])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='infer E2E')
parser.add_argument("--data_path", default="img_test", help='data path')
parser.add_argument("--device", default=0, type=int, help='npu device')
parser.add_argument("--om_path", default="db_ic15_resnet50_16.om", help='om path')
flags = parser.parse_args()
db_session = InferSession(flags.device, flags.om_path)
main(flags.data_path,db_session)