From 1002d921ab3bb1e56e7a29d3db1e270476cf34f0 Mon Sep 17 00:00:00 2001 From: anders Date: Thu, 6 Nov 2025 19:13:15 +0800 Subject: [PATCH 1/3] update yolov10 --- .../object_detection/yolov10/igie/README.md | 3 +- .../object_detection/yolov10/igie/quantize.py | 119 ++++++++++ .../scripts/infer_yolov10_int8_accuracy.sh | 46 ++++ .../scripts/infer_yolov10_int8_performance.sh | 47 ++++ .../cv/object_detection/yolov10/igie/utils.py | 208 ++++++++++++++++++ 5 files changed, 422 insertions(+), 1 deletion(-) create mode 100644 models/cv/object_detection/yolov10/igie/quantize.py create mode 100644 models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh create mode 100644 models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh create mode 100644 models/cv/object_detection/yolov10/igie/utils.py diff --git a/models/cv/object_detection/yolov10/igie/README.md b/models/cv/object_detection/yolov10/igie/README.md index 393d12d1..12dba10a 100644 --- a/models/cv/object_detection/yolov10/igie/README.md +++ b/models/cv/object_detection/yolov10/igie/README.md @@ -85,7 +85,8 @@ bash scripts/infer_yolov10_fp16_performance.sh | Model | BatchSize | Precision | FPS | IOU@0.5 | IOU@0.5:0.95 | | ------- | --------- | --------- | ------ | ------- | ------------ | -| YOLOv10 | 32 | FP16 | 810.97 | 0.629 | 0.461 | +| YOLOv10 | 32 | FP16 | 528.685 | 0.629 | 0.461 | +| YOLOv10 | 32 | INT8 | 599.318 | 0.618 | 0.444 | ## References diff --git a/models/cv/object_detection/yolov10/igie/quantize.py b/models/cv/object_detection/yolov10/igie/quantize.py new file mode 100644 index 00000000..522e00fc --- /dev/null +++ b/models/cv/object_detection/yolov10/igie/quantize.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
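+
+# quantize.py: post-training static INT8 quantization for the exported
+# YOLOv10 ONNX model. Activation ranges are calibrated on COCO val2017
+# batches, and accuracy-sensitive detect-head / attention nodes are kept
+# in floating point (see nodes_to_exclude below).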
+
+import os
+import onnx
+import psutil
+import argparse
+import numpy as np
+from inference import get_dataloader
+from onnxruntime.quantization import (CalibrationDataReader, QuantFormat,
+                                      quantize_static, QuantType,
+                                      CalibrationMethod)
+
+class CalibrationDataLoader(CalibrationDataReader):
+    def __init__(self, input_name, dataloader, cnt_limit=100):
+        self.cnt = 0
+        self.input_name = input_name
+        self.cnt_limit = cnt_limit
+        self.iter = iter(dataloader)
+
+    # avoid OOM: stop feeding calibration data once host memory usage is high
+    @staticmethod
+    def _exceed_memory_upper_bound(upper_bound=80):
+        info = psutil.virtual_memory()
+        total_percent = info.percent
+        if total_percent >= upper_bound:
+            return True
+        return False
+
+    def get_next(self):
+        if self._exceed_memory_upper_bound() or self.cnt >= self.cnt_limit:
+            return None
+        self.cnt += 1
+        print(f"onnx calibration data count: {self.cnt}")
+        input_info = next(self.iter)
+
+        ort_input = {k: np.array(v) for k, v in zip(self.input_name, input_info)}
+        return ort_input
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model_path",
+                        type=str,
+                        required=True,
+                        help="original model path.")
+
+    parser.add_argument("--out_path",
+                        type=str,
+                        required=True,
+                        help="output path of the quantized onnx model.")
+
+    parser.add_argument("--datasets",
+                        type=str,
+                        required=True,
+                        help="calibration datasets path.")
+
+    parser.add_argument("--num_workers",
+                        type=int,
+                        default=16,
+                        help="number of workers used in pytorch dataloader.")
+
+    args = parser.parse_args()
+
+    return args
+
+def main():
+    args = parse_args()
+
+    model = onnx.load(args.model_path)
+    input_names = [input.name for input in model.graph.input]
+
+    data_path = os.path.join(args.datasets, "images", "val2017")
+    label_path = os.path.join(args.datasets, "annotations", "instances_val2017.json")
+
+    dataloader = get_dataloader(data_path, label_path, batch_size=1, num_workers=args.num_workers)
+    calibration = CalibrationDataLoader(input_names, dataloader, cnt_limit=20)
+
+    quantize_static(args.model_path,
+                    args.out_path,
+                    calibration_data_reader=calibration,
+                    quant_format=QuantFormat.QOperator,
+                    op_types_to_quantize=['Conv'],
+                    per_channel=False,
+                    activation_type=QuantType.QInt8,
+                    weight_type=QuantType.QInt8,
+                    use_external_data_format=False,
+                    nodes_to_exclude=[
+                        '/model.23/Add_9',
+                        '/model.23/Concat_24',
+                        '/model.23/Concat_25',
+                        '/model.23/Concat_26',
+                        '/model.23/Concat_28',
+                        '/model.23/Concat_30',
+                        '/model.23/Concat_32',
+                        '/model.23/Mul_3',
+                        '/model.10/attn/Softmax',
+                        '/model.23/dfl/Softmax'],
+                    calibrate_method=CalibrationMethod.Percentile,
+                    extra_options={
+                        'ActivationSymmetric': True,
+                        'WeightSymmetric': True
+                    }
+                )
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh
new file mode 100644
index 00000000..d2b7520d
--- /dev/null
+++ b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +batchsize=32 +model_path="yolov10s.onnx" + +quantized_model_path="yolov10s_int8.onnx" + +datasets_path=${DATASETS_DIR} + +if [ ! -e $quantized_model_path ]; then + # quantize model to int8 + python3 quantize.py \ + --model_path ${model_path} \ + --out_path ${quantized_model_path} \ + --datasets ${datasets_path} +fi + + +# build engine +python3 ../../igie_common/build_engine.py \ + --model_path ${quantized_model_path} \ + --input images:${batchsize},3,640,640 \ + --precision int8 \ + --engine_path yolov10_bs_${batchsize}_int8.so + +# inference +python3 inference.py \ + --engine yolov10_bs_${batchsize}_int8.so \ + --batchsize ${batchsize} \ + --input_name images \ + --datasets ${datasets_path} diff --git a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh new file mode 100644 index 00000000..6829fde1 --- /dev/null +++ b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +batchsize=32 +model_path="yolov10s.onnx" + +quantized_model_path="yolov10s_int8.onnx" + +datasets_path=${DATASETS_DIR} + +if [ ! -e $quantized_model_path ]; then + # quantize model to int8 + python3 quantize.py \ + --model_path ${model_path} \ + --out_path ${quantized_model_path} \ + --batch ${batchsize} \ + --datasets ${datasets_path} +fi + +# build engine +python3 ../../igie_common/build_engine.py \ + --model_path ${quantized_model_path} \ + --input images:${batchsize},3,640,640 \ + --precision int8 \ + --engine_path yolov10_bs_${batchsize}_int8.so + +# inference +python3 inference.py \ + --engine yolov10_bs_${batchsize}_int8.so \ + --batchsize ${batchsize} \ + --input_name images \ + --datasets ${datasets_path} \ + --perf_only True diff --git a/models/cv/object_detection/yolov10/igie/utils.py b/models/cv/object_detection/yolov10/igie/utils.py new file mode 100644 index 00000000..1b35ad76 --- /dev/null +++ b/models/cv/object_detection/yolov10/igie/utils.py @@ -0,0 +1,208 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import cv2 +import torch +import numpy as np + +from pycocotools.coco import COCO + +coco80_to_coco91 = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, + 89, 90 +] + +coco91_to_coco80_dict = {i: idx for idx, i in enumerate(coco80_to_coco91)} + +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): + # Resize and pad image while meeting stride-multiple constraints + # current shape [height, width] + + shape = im.shape[:2] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + + # Compute padding + ratio = r, r + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + + dw /= 2 + dh /= 2 + + if shape[::-1] != new_unpad: + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, + top, + bottom, + left, + right, + cv2.BORDER_CONSTANT, + value=color) + return im, ratio, (dw, dh) + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x + y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y + y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x + y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y + return y + +def clip_boxes(boxes, shape): + # Clip boxes (xyxy) to image shape (height, width) + if isinstance(boxes, torch.Tensor): # faster individually + boxes[:, 0].clamp_(0, shape[1]) # x1 + boxes[:, 1].clamp_(0, shape[0]) # y1 + boxes[:, 2].clamp_(0, shape[1]) # x2 + boxes[:, 3].clamp_(0, shape[0]) # y2 + else: # np.array (faster grouped) + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 + +def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right + if clip: + clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center + y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center + y[:, 2] = (x[:, 2] - x[:, 0]) / w # width + y[:, 3] = (x[:, 3] - x[:, 1]) / h # height + return y + +class COCO2017Dataset(torch.utils.data.Dataset): + def __init__(self, + image_dir_path, + label_json_path, + image_size=640, + pad_color=114, + val_mode=True, + input_layout="NCHW"): + + self.image_dir_path = image_dir_path + self.label_json_path = label_json_path + 
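+        # preprocessing options: square letterbox target size, constant border
+        # fill value, val/train image-id selection, and output tensor layout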
+        self.image_size = image_size
+        self.pad_color = pad_color
+        self.val_mode = val_mode
+        self.input_layout = input_layout
+
+        self.coco = COCO(annotation_file=self.label_json_path)
+
+        if self.val_mode:
+            self.img_ids = list(sorted(self.coco.imgs.keys()))
+        else:
+            self.img_ids = sorted(list(self.coco.imgToAnns.keys()))
+
+    def __len__(self):
+        return len(self.img_ids)
+
+    def __getitem__(self, index):
+        img_path = self._get_image_path(index)
+        img, (h0, w0), (h, w) = self._load_image(index)
+
+        img, ratio, pad = letterbox(img,
+                                    self.image_size,
+                                    color=(self.pad_color, self.pad_color, self.pad_color))
+        shapes = (h0, w0), ((h / h0, w / w0), pad)
+
+        # load label
+        raw_label = self._load_json_label(index)
+        # normalized xywh to pixel xyxy format
+        raw_label[:, 1:] = xywhn2xyxy(raw_label[:, 1:],
+                                      ratio[0] * w,
+                                      ratio[1] * h,
+                                      padw=pad[0],
+                                      padh=pad[1])
+
+        raw_label[:, 1:] = xyxy2xywhn(raw_label[:, 1:],
+                                      w=img.shape[1],
+                                      h=img.shape[0],
+                                      clip=True,
+                                      eps=1E-3)
+
+        nl = len(raw_label)
+        labels_out = np.zeros((nl, 6))
+        labels_out[:, 1:] = raw_label
+
+        # HWC to CHW, BGR to RGB
+        img = img.transpose((2, 0, 1))[::-1]
+        img = np.ascontiguousarray(img) / 255.0
+        if self.input_layout == "NHWC":
+            img = img.transpose((1, 2, 0))
+
+        return img, labels_out, img_path, shapes
+
+    def _get_image_path(self, index):
+        idx = self.img_ids[index]
+        path = self.coco.loadImgs(idx)[0]["file_name"]
+        img_path = os.path.join(self.image_dir_path, path)
+        return img_path
+
+    def _load_image(self, index):
+        img_path = self._get_image_path(index)
+
+        im = cv2.imread(img_path)
+        h0, w0 = im.shape[:2]
+        r = self.image_size / max(h0, w0)
+        if r != 1:
+            im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+        return im.astype("float32"), (h0, w0), im.shape[:2]
+
+    def _load_json_label(self, index):
+        _, (h0, w0), _ = self._load_image(index)
+
+        idx = self.img_ids[index]
+        ann_ids = self.coco.getAnnIds(imgIds=idx)
+        targets = self.coco.loadAnns(ids=ann_ids)
+
+        labels = []
+        for target in targets:
+            cat = target["category_id"]
+            coco80_cat = coco91_to_coco80_dict[cat]
+            cat = np.array([[coco80_cat]])
+
+            x, y, w, h = target["bbox"]
+            x1, y1, x2, y2 = x, y, int(x + w), int(y + h)
+            xyxy = np.array([[x1, y1, x2, y2]])
+            xywhn = xyxy2xywhn(xyxy, w0, h0)
+            labels.append(np.hstack((cat, xywhn)))
+
+        if labels:
+            labels = np.vstack(labels)
+        else:
+            if self.val_mode:
+                labels = np.zeros((1, 5))
+            else:
+                raise ValueError("image has no annotations; set val_mode = True to allow unlabeled images")
+
+        return labels
+
+    @staticmethod
+    def collate_fn(batch):
+        im, label, path, shapes = zip(*batch)
+        for i, lb in enumerate(label):
+            lb[:, 0] = i
+        return np.concatenate([i[None] for i in im], axis=0), np.concatenate(label, 0), path, shapes
-- 
Gitee

From a711cbc0013b88272bc40c9ccf3fbe3fb2a073d7 Mon Sep 17 00:00:00 2001
From: "dun.zhang" 
Date: Fri, 7 Nov 2025 11:19:24 +0800
Subject: [PATCH 2/3] update yolov10

---
 .../cv/object_detection/yolov10/igie/inference.py | 15 +++++++++++++--
 .../scripts/infer_yolov10_int8_performance.sh     |  1 -
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/models/cv/object_detection/yolov10/igie/inference.py b/models/cv/object_detection/yolov10/igie/inference.py
index d7c2430e..d8a41a2d 100644
--- a/models/cv/object_detection/yolov10/igie/inference.py
+++ b/models/cv/object_detection/yolov10/igie/inference.py
@@ -16,7 +16,7 @@ import os
 import argparse
 import tvm
 from tvm import relay
-
+import torch
 import numpy as np
 from pathlib import Path
 from ultralytics import YOLO
@@ -24,7 +24,7 @@ from ultralytics.cfg import get_cfg
 from ultralytics.utils import DEFAULT_CFG
 from validator import IGIE_Validator
-
+from utils import COCO2017Dataset
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -73,6 +73,17 @@ def parse_args():
 
     return args
 
+def get_dataloader(data_path, label_path, batch_size, num_workers):
+
+    dataset = COCO2017Dataset(data_path, label_path, image_size=640)
+
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             drop_last=False,
+                                             num_workers=num_workers,
+                                             collate_fn=dataset.collate_fn)
+    return dataloader
+
 def main():
     args = parse_args()
 
diff --git a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
index 6829fde1..6d91b19e 100644
--- a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
+++ b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
@@ -27,7 +27,6 @@ if [ ! -e $quantized_model_path ]; then
     python3 quantize.py \
         --model_path ${model_path} \
         --out_path ${quantized_model_path} \
-        --batch ${batchsize} \
        --datasets ${datasets_path}
 fi
 
-- 
Gitee

From 97df32aece5519c874f92073edf3819b2593b73a Mon Sep 17 00:00:00 2001
From: "hongliang.yuan" 
Date: Fri, 7 Nov 2025 16:18:56 +0800
Subject: [PATCH 3/3] update yolov10

---
 .../object_detection/yolov10/igie/README.md   | 28 +++++++++++++++++++
 .../yolov10/igie/inference.py                 |  2 +-
 .../yolov10/igie/requirements.txt             |  3 ++
 .../scripts/infer_yolov10_int8_accuracy.sh    |  2 +-
 .../scripts/infer_yolov10_int8_performance.sh |  2 +-
 tests/model_info.json                         |  3 +-
 6 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/models/cv/object_detection/yolov10/igie/README.md b/models/cv/object_detection/yolov10/igie/README.md
index 12dba10a..0d5df449 100644
--- a/models/cv/object_detection/yolov10/igie/README.md
+++ b/models/cv/object_detection/yolov10/igie/README.md
@@ -59,6 +59,25 @@ pip3 install -r requirements.txt
 ```bash
 git clone --depth 1 https://github.com/THU-MIG/yolov10.git
 cd yolov10/
+
+```
+
+```python
+# Modify ultralytics/engine/exporter.py as follows:
+--- a/ultralytics/engine/exporter.py
++++ b/ultralytics/engine/exporter.py
+@@ -373,6 +373,7 @@ class Exporter:
+         elif isinstance(self.model, DetectionModel):
+             dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
+
++        dynamic = {'images': {0: 'batch'}, 'output0': {0: 'batch'}}
+         torch.onnx.export(
+             self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
+```
+
+```bash
 pip3 install -e . --no-deps
 cd ../
@@ -81,6 +100,15 @@ bash scripts/infer_yolov10_fp16_accuracy.sh
 bash scripts/infer_yolov10_fp16_performance.sh
 ```
 
+### INT8
+
+```bash
+# Accuracy
+bash scripts/infer_yolov10_int8_accuracy.sh
+# Performance
+bash scripts/infer_yolov10_int8_performance.sh
+```
+
 ## Model Results
 
 | Model   | BatchSize | Precision | FPS    | IOU@0.5 | IOU@0.5:0.95 |
diff --git a/models/cv/object_detection/yolov10/igie/inference.py b/models/cv/object_detection/yolov10/igie/inference.py
index d8a41a2d..093eec50 100644
--- a/models/cv/object_detection/yolov10/igie/inference.py
+++ b/models/cv/object_detection/yolov10/igie/inference.py
@@ -143,7 +143,7 @@ def main():
     validator = IGIE_Validator(args=cfg_args, save_dir=Path('.'))
     validator.stride = 32
 
-    stats = validator(module, device)
+    validator(module, device)
 
 if __name__ == "__main__":
     main()
diff --git a/models/cv/object_detection/yolov10/igie/requirements.txt b/models/cv/object_detection/yolov10/igie/requirements.txt
index 682cf964..aff22731 100644
--- a/models/cv/object_detection/yolov10/igie/requirements.txt
+++ b/models/cv/object_detection/yolov10/igie/requirements.txt
@@ -1,2 +1,5 @@
 tqdm
 huggingface_hub==0.25.2
+onnx==1.15.0
+matplotlib
+pycocotools
\ No newline at end of file
diff --git a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh
index d2b7520d..95121cad 100644
--- a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh
+++ b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_accuracy.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
 # All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
diff --git a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
index 6d91b19e..c03fae0f 100644
--- a/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
+++ b/models/cv/object_detection/yolov10/igie/scripts/infer_yolov10_int8_performance.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
 # All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
diff --git a/tests/model_info.json b/tests/model_info.json
index 1ff2d05d..42e6dd4c 100644
--- a/tests/model_info.json
+++ b/tests/model_info.json
@@ -4050,7 +4050,8 @@
         "download_url": "https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10s.pt",
         "need_third_part": true,
         "precisions": [
-            "fp16"
+            "fp16",
+            "int8"
         ],
         "type": "inference",
         "hasDemo": false,
-- 
Gitee
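
Before building the IGIE engine, the quantized model can be sanity-checked
against the source model under onnxruntime. The sketch below is illustrative
only and not part of the patches: it assumes yolov10s.onnx and
yolov10s_int8.onnx sit in the working directory and that the input tensor is
named "images", as in the scripts above.

    import numpy as np
    import onnxruntime as ort

    # random input is enough for a smoke test; the exported model has a dynamic batch axis
    x = np.random.rand(1, 3, 640, 640).astype(np.float32)

    ref = ort.InferenceSession("yolov10s.onnx").run(None, {"images": x})[0]
    quant = ort.InferenceSession("yolov10s_int8.onnx").run(None, {"images": x})[0]

    # INT8 outputs should track the source model closely; a large gap usually
    # points at a bad calibration run (too few images or an early OOM cutoff)
    print("max abs diff:", float(np.abs(ref - quant).max()))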