A Complete Guide to Automatic Image Annotation in Python with FastSAM
In computer vision, data annotation is the foundation of model training, but manual labeling is slow and labor-intensive. This article introduces a Python-based auto-labeling tool that combines FastSAM (the Fast Segment Anything Model) with a YOLO classification model to generate high-quality annotations efficiently.
Overview
FastSAM is an accelerated variant of SAM (Segment Anything Model) that delivers a large speedup while maintaining good accuracy. Our auto-labeling tool uses FastSAM for object detection and segmentation, classifies each detected object with a YOLO model, and finally writes annotation files in YOLO format.
This article walks through the code structure, the underlying approach, and usage, helping readers quickly master the core techniques of automatic image annotation.
Environment Setup
Before starting, install the required Python libraries. Python 3.7 or later is recommended.
pip install torch torchvision
pip install opencv-python pillow
pip install ultralytics
pip install numpy tqdm
(pathlib and argparse ship with the Python standard library and do not need to be installed separately.)
Make sure you have downloaded the FastSAM model file (e.g. FastSAM-x.pt) and a YOLO classification model (e.g. yolov8n-cls.pt).
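To confirm the environment works, a quick smoke test like the following can be used. This is a minimal sketch: it assumes the weight files sit in the working directory, and recent ultralytics releases will download the official weights automatically if they are missing.
import torch
from ultralytics import FastSAM, YOLO

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

seg_model = FastSAM('FastSAM-x.pt')   # segmentation model
cls_model = YOLO('yolov8n-cls.pt')    # classification model
print("Models loaded successfully")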
Project Structure
fastsam_autolabeler/
├── autolabeler.py        # main program
├── models/               # model directory
│   ├── FastSAM-x.pt
│   └── yolov8n-cls.pt
├── images/               # input image directory
├── dataset/              # output annotation data
│   ├── images/
│   ├── labels/
│   └── visualization/
└── README.md
Core Code Walkthrough
1. The Main Class: FastSAMAutoLabeler
The FastSAMAutoLabeler class is the core of the tool; it orchestrates the entire labeling pipeline.
class FastSAMAutoLabeler:
    def __init__(self, model_path='FastSAM-x.pt',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 classification_model='yolov8n-cls.pt'):
        self.device = device
        self.model = FastSAM(model_path)
        self.classification_model = YOLO(classification_model)
Initialization loads two models: FastSAM for object detection and segmentation, and YOLO for object classification. The code automatically detects the available compute device, preferring GPU acceleration.
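For reference, a typical instantiation matching the project layout above might look like this (the paths are illustrative, not part of the original code):
labeler = FastSAMAutoLabeler(
    model_path='models/FastSAM-x.pt',             # assumed location of the weights
    classification_model='models/yolov8n-cls.pt'
)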
2. The Image Processing Pipeline
The process_image method is the main processing pipeline and consists of the following steps:
def process_image(self, image_path, output_dir, conf=0.4, iou=0.9,
                  min_area_ratio=0.001, max_area_ratio=0.95):
    # Read the image
    image = cv2.imread(image_path)
    height, width = image.shape[:2]
    # FastSAM inference
    everything_results = self.model(image_path, device=self.device,
                                    retina_masks=True, imgsz=1024, conf=conf, iou=iou)
    # Process the mask data (`ann` is the mask tensor extracted from the results,
    # and `image_area` is height * width; see the full code below)
    detections = self._process_masks_manually(ann, everything_results, width, height)
    # Filter the detections
    filtered_detections = self._filter_detections(detections, image_area, min_area_ratio, max_area_ratio)
    # Classify each object
    classified_detections = self._classify_objects(image, filtered_detections)
    # Generate annotations
    return self._generate_annotations(image, classified_detections, output_dir, Path(image_path).stem)
This method implements the full flow from reading the image to generating annotations, with appropriate error handling at each step.
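Calling it on a single image returns a small summary dict (the input path here is hypothetical):
result = labeler.process_image('images/sample.jpg', 'dataset', conf=0.4, iou=0.9)
print(result['label_path'], result['detections_count'])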
3. Mask Processing and Bounding-Box Extraction
The _process_masks_manually method converts the segmentation masks output by FastSAM into bounding boxes:
def _process_masks_manually(self, ann, everything_results, img_width, img_height):
    masks_np = ann.cpu().numpy()
    num_masks = masks_np.shape[0]
    boxes = []
    for i in range(num_masks):
        mask = masks_np[i]
        y_indices, x_indices = np.where(mask > 0.5)  # thresholding
        # Compute the bounding box
        x1 = np.min(x_indices)
        y1 = np.min(y_indices)
        x2 = np.max(x_indices)
        y2 = np.max(y_indices)
        boxes.append([x1, y1, x2, y2])
This approach needs no extra computer vision libraries; the mask handling is fully self-contained.
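The idea is easy to verify in isolation with a synthetic mask (NumPy only; the values are made up for demonstration):
import numpy as np

# Synthetic 8x8 mask with an active blob in rows 2-4, columns 3-6
mask = np.zeros((8, 8), dtype=np.float32)
mask[2:5, 3:7] = 1.0

y_indices, x_indices = np.where(mask > 0.5)
x1, y1 = x_indices.min(), y_indices.min()
x2, y2 = x_indices.max(), y_indices.max()
print(x1, y1, x2, y2)  # -> 3 2 6 4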
4. Object Classification
The _classify_objects method classifies each detected object:
def _classify_objects(self, image, detections):
    for i, bbox in enumerate(detections['boxes']):
        x1, y1, x2, y2 = map(int, bbox)
        object_image = image[y1:y2, x1:x2]
        object_image_resized = cv2.resize(object_image, (224, 224))
        # Run the YOLO classification model
        results = self.classification_model(object_image_resized)
        top1 = results[0].probs.top1
        top1conf = results[0].probs.top1conf.item()
By combining instance segmentation with a classification model, the tool can both locate and identify the objects in an image.
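The predicted index can be turned into a human-readable label via the `names` mapping on the ultralytics result object. A standalone sketch (the crop file name is hypothetical):
from ultralytics import YOLO
import cv2

cls_model = YOLO('yolov8n-cls.pt')
crop = cv2.imread('object_crop.jpg')   # hypothetical cropped object
crop = cv2.resize(crop, (224, 224))

results = cls_model(crop)
probs = results[0].probs
print(results[0].names[probs.top1], probs.top1conf.item())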
Helper Classes
1. Annotation Visualizer
The ManualAnnotationVisualizer class renders annotation results for visual inspection:
class ManualAnnotationVisualizer:
    def draw_annotations(self, image_path, label_path=None, detections=None, output_path=None):
        # Draw bounding boxes and labels
        for i, bbox in enumerate(detections['boxes']):
            x1, y1, x2, y2 = map(int, bbox)
            color = self.colors[i % len(self.colors)]
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
Each detected object is drawn in a different color, together with its class label and confidence.
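A typical call, reading boxes back from a generated label file (paths and class names are illustrative):
visualizer = ManualAnnotationVisualizer(class_names=['cat', 'dog'])
visualizer.draw_annotations(
    'dataset/images/sample.jpg',
    label_path='dataset/labels/sample.txt',
    output_path='dataset/visualization/sample_annotated.jpg'
)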
2. Annotation Validator
The AnnotationValidator class checks the quality of annotation files:
class AnnotationValidator:
    def validate_annotations(self, image_path, label_path):
        # Check value ranges, box validity, overlaps, etc.
        issues = []
        for i, line in enumerate(lines):
            # Validate the format and values of each annotation line
            if not (0 <= x_center <= 1):
                issues.append(f"Line {i+1}: x_center out of range [0,1]: {x_center}")
The validator helps users spot problems in the annotations, ensuring the quality of the generated dataset.
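Usage mirrors the visualizer (the paths are illustrative); the returned report dict can be inspected directly:
validator = AnnotationValidator(class_names=['cat', 'dog'])
report = validator.validate_annotations('dataset/images/sample.jpg',
                                        'dataset/labels/sample.txt')
for issue in report.get('issues', []):
    print(issue)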
Usage
Command-Line Arguments
The tool supports a rich set of command-line arguments to cover different scenarios:
python autolabeler.py \
--input images/ \
--output dataset/ \
--model models/FastSAM-x.pt \
--classification-model models/yolov8n-cls.pt \
--conf 0.4 \
--iou 0.9 \
--visualize \
--validate
The main arguments are:
- --input: input image path (file or directory)
- --output: output directory
- --conf: detection confidence threshold
- --iou: IoU threshold for non-maximum suppression
- --visualize: generate visualization results
- --validate: validate annotation quality
Batch Processing
The tool supports both single-image and batch modes. When the input is a directory, it automatically walks all image files with supported extensions:
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
for img_path in tqdm(image_files, desc="Processing images"):
    result = labeler.process_image(str(img_path), args.output, args.conf, args.iou)
A progress bar is shown during batch processing so users can track progress.
Extensions
1. Custom Filtering Rules
Users can adjust how detections are filtered, for example by object area relative to the image:
min_area_ratio = 0.001  # minimum area ratio (relative to the image area)
max_area_ratio = 0.95   # maximum area ratio
This helps drop detections that are too small or too large, improving annotation quality.
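As a concrete illustration (the numbers are chosen for demonstration), the default thresholds translate to the following pixel areas on a 1920x1080 image:
image_area = 1920 * 1080        # 2,073,600 px
min_area = 0.001 * image_area   # ~2,074 px: boxes smaller than roughly 46x46 are dropped
max_area = 0.95 * image_area    # ~1,969,920 px: near-full-frame boxes are dropped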
2. Multi-Class Support
By swapping in a different YOLO classification model, the tool can adapt to other domains and class sets, for example a specialized classifier trained for a specific scenario.
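For instance, custom weights can simply be passed in at construction time (the file name below is hypothetical):
labeler = FastSAMAutoLabeler(
    model_path='models/FastSAM-x.pt',
    classification_model='models/my_custom_classifier.pt'  # hypothetical custom weights
)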
Complete Code
import torch
import cv2
import numpy as np
from ultralytics import FastSAM, YOLO
import os
from pathlib import Path
import argparse
from tqdm import tqdm
class FastSAMAutoLabeler:
    def __init__(self, model_path='FastSAM-x.pt',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 classification_model='yolov8n-cls.pt'):
        """
        FastSAM auto-labeling tool (no dependency on supervision)
        """
        self.device = device
        self.model = FastSAM(model_path)
        # Load the classification model
        self.classification_model = YOLO(classification_model)
        print(f"Models loaded, using device: {device}")
    def process_image(self, image_path, output_dir, conf=0.4, iou=0.9,
                      min_area_ratio=0.001, max_area_ratio=0.95):
        """
        Process a single image and generate annotations
        """
        image = None  # ensure `image` is defined even if reading fails below
        try:
            # Read the image
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Failed to read image: {image_path}")
            height, width = image.shape[:2]
            image_area = height * width
            print(f"Image size: {width}x{height}, area: {image_area}")
            # Run FastSAM inference
            everything_results = self.model(
                image_path,
                device=self.device,
                retina_masks=True,
                imgsz=1024,
                conf=conf,
                iou=iou
            )
            # Check whether any results were returned
            if not everything_results or len(everything_results) == 0:
                print("Warning: no detection results")
                return self._create_empty_result(image, output_dir, Path(image_path).stem)
            # Get the mask data
            masks_data = everything_results[0].masks
            if masks_data is None:
                print("Warning: no masks detected")
                return self._create_empty_result(image, output_dir, Path(image_path).stem)
            ann = masks_data.data
            print(f"Raw mask shape: {ann.shape}")
            # Normalize mask dimensions
            if len(ann.shape) == 2:
                ann = ann.unsqueeze(0)
            elif len(ann.shape) != 3:
                raise ValueError(f"Unsupported mask shape: {ann.shape}")
            # Process the detections manually
            detections = self._process_masks_manually(ann, everything_results, width, height)
            # Filter the detections
            filtered_detections = self._filter_detections(
                detections, image_area, min_area_ratio, max_area_ratio
            )
            # Classify each detected object
            classified_detections = self._classify_objects(image, filtered_detections)
            return self._generate_annotations(image, classified_detections, output_dir, Path(image_path).stem)
        except Exception as e:
            print(f"Error while processing image: {e}")
            import traceback
            traceback.print_exc()
            return self._create_error_result(image, output_dir, Path(image_path).stem, str(e))
    def _process_masks_manually(self, ann, everything_results, img_width, img_height):
        """
        Process mask data manually, replacing the functionality of supervision
        """
        # Convert the tensor to a numpy array
        if isinstance(ann, torch.Tensor):
            masks_np = ann.cpu().numpy()
        else:
            masks_np = ann
        print(f"Mask numpy array shape: {masks_np.shape}")
        # Ensure the array is 3-D: [N, H, W]
        if len(masks_np.shape) == 2:
            masks_np = np.expand_dims(masks_np, axis=0)
        num_masks = masks_np.shape[0]
        print(f"Detected {num_masks} masks")
        if num_masks == 0:
            return self._create_empty_detections()
        # Compute a bounding box and related info for each mask
        boxes = []
        confidences = []
        class_ids = []
        masks = []
        for i in range(num_masks):
            mask = masks_np[i]
            # Find the pixel positions where the mask is active
            y_indices, x_indices = np.where(mask > 0.5)  # thresholding
            if len(x_indices) == 0 or len(y_indices) == 0:
                continue
            # Compute the bounding box
            x1 = np.min(x_indices)
            y1 = np.min(y_indices)
            x2 = np.max(x_indices)
            y2 = np.max(y_indices)
            # Compute area and a confidence proxy (mask area relative to box area)
            bbox_area = (x2 - x1) * (y2 - y1)
            mask_area = len(x_indices)
            confidence = min(mask_area / bbox_area, 1.0) if bbox_area > 0 else 0
            boxes.append([x1, y1, x2, y2])
            confidences.append(confidence)
            class_ids.append(0)  # default class ID
            masks.append(mask)
        if not boxes:
            return self._create_empty_detections()
        return {
            'boxes': np.array(boxes),
            'confidences': np.array(confidences),
            'class_ids': np.array(class_ids),
            'masks': np.array(masks)
        }
    def _filter_detections(self, detections, image_area, min_area_ratio, max_area_ratio):
        """Filter detections by area"""
        if len(detections['boxes']) == 0:
            return detections
        filtered_boxes = []
        filtered_confidences = []
        filtered_class_ids = []
        filtered_masks = []
        for i, bbox in enumerate(detections['boxes']):
            x1, y1, x2, y2 = bbox
            area = (x2 - x1) * (y2 - y1)
            area_ratio = area / image_area
            if min_area_ratio <= area_ratio <= max_area_ratio:
                filtered_boxes.append(bbox)
                filtered_confidences.append(detections['confidences'][i])
                filtered_class_ids.append(detections['class_ids'][i])
                if i < len(detections['masks']):
                    filtered_masks.append(detections['masks'][i])
        filtered_detections = {
            'boxes': np.array(filtered_boxes) if filtered_boxes else np.empty((0, 4)),
            'confidences': np.array(filtered_confidences) if filtered_confidences else np.empty(0),
            'class_ids': np.array(filtered_class_ids) if filtered_class_ids else np.empty(0),
            'masks': np.array(filtered_masks) if filtered_masks else np.empty(0)
        }
        print(f"Kept {len(filtered_boxes)} detections after filtering")
        return filtered_detections
    def _classify_objects(self, image, detections):
        """
        Classify each detected object
        """
        if len(detections['boxes']) == 0:
            return detections
        classified_class_ids = []
        classified_confidences = []
        for i, bbox in enumerate(detections['boxes']):
            x1, y1, x2, y2 = map(int, bbox)
            # Crop the object region
            object_image = image[y1:y2, x1:x2]
            if object_image.size == 0:
                classified_class_ids.append(0)
                classified_confidences.append(detections['confidences'][i])
                continue
            # Resize to fit the classification model
            object_image_resized = cv2.resize(object_image, (224, 224))
            # Run the classification model
            try:
                results = self.classification_model(object_image_resized)
                # Take the top-1 class
                top1 = results[0].probs.top1
                top1conf = results[0].probs.top1conf.item()
                classified_class_ids.append(top1)
                classified_confidences.append(top1conf)
            except Exception as e:
                print(f"Classification error: {e}")
                # Keep the original class if classification fails
                classified_class_ids.append(detections['class_ids'][i])
                classified_confidences.append(detections['confidences'][i])
        # Update the detections
        detections['class_ids'] = np.array(classified_class_ids)
        detections['confidences'] = np.array(classified_confidences)
        return detections
    def _generate_annotations(self, image, detections, output_dir, image_name):
        """Generate a YOLO-format annotation file"""
        height, width = image.shape[:2]
        # Create output directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'labels'), exist_ok=True)
        # Save the image
        image_output_path = os.path.join(output_dir, 'images', f'{image_name}.jpg')
        cv2.imwrite(image_output_path, image)
        # Build YOLO-format annotations
        yolo_annotations = []
        for i, bbox in enumerate(detections['boxes']):
            x1, y1, x2, y2 = bbox
            # Convert to YOLO format (normalized center coordinates and size)
            x_center = ((x1 + x2) / 2) / width
            y_center = ((y1 + y2) / 2) / height
            w = (x2 - x1) / width
            h = (y2 - y1) / height
            # Clamp to the valid range
            x_center = max(0, min(1, x_center))
            y_center = max(0, min(1, y_center))
            w = max(0, min(1, w))
            h = max(0, min(1, h))
            # Skip boxes that are too small
            if w < 0.001 or h < 0.001:
                continue
            # Get the class ID and confidence
            class_id = int(detections['class_ids'][i]) if i < len(detections['class_ids']) else 0
            confidence = detections['confidences'][i] if i < len(detections['confidences']) else 1.0
            yolo_annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}")
        # Save the YOLO labels
        label_path = os.path.join(output_dir, 'labels', f'{image_name}.txt')
        with open(label_path, 'w') as f:
            f.write('\n'.join(yolo_annotations))
        return {
            'image_path': image_output_path,
            'label_path': label_path,
            'detections_count': len(yolo_annotations),
            'image_name': image_name
        }
    def _create_empty_detections(self):
        """Create an empty detection result"""
        return {
            'boxes': np.empty((0, 4)),
            'confidences': np.empty(0),
            'class_ids': np.empty(0),
            'masks': np.empty(0)
        }
    def _create_empty_result(self, image, output_dir, image_name):
        """Create an empty result"""
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, 'labels'), exist_ok=True)
        image_output_path = os.path.join(output_dir, 'images', f'{image_name}.jpg')
        if image is not None:  # the image may be missing if reading failed
            cv2.imwrite(image_output_path, image)
        label_path = os.path.join(output_dir, 'labels', f'{image_name}.txt')
        with open(label_path, 'w') as f:
            pass
        return {
            'image_path': image_output_path,
            'label_path': label_path,
            'detections_count': 0,
            'image_name': image_name
        }

    def _create_error_result(self, image, output_dir, image_name, error_msg):
        """Create an error result"""
        print(f"Creating an error result for image {image_name}: {error_msg}")
        return self._create_empty_result(image, output_dir, image_name)
class ManualAnnotationVisualizer:
    """
    Hand-rolled annotation visualizer (no dependency on supervision)
    """
    def __init__(self, class_names=None, colors=None):
        self.class_names = class_names or ['object']
        self.colors = colors or self._generate_default_colors()

    def _generate_default_colors(self):
        """Generate a default color list (OpenCV BGR order)"""
        return [
            (255, 0, 0),      # blue
            (0, 255, 0),      # green
            (0, 0, 255),      # red
            (255, 255, 0),    # cyan
            (255, 0, 255),    # magenta
            (0, 255, 255),    # yellow
            (255, 165, 0),    # azure
            (128, 0, 128),    # purple
            (255, 192, 203),  # lavender
            (165, 42, 42),    # dark blue
        ]
    def draw_annotations(self, image_path, label_path=None, detections=None,
                         output_path=None, show_labels=True, show_confidences=True):
        """
        Draw annotation results
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to read image: {image_path}")
        height, width = image.shape[:2]
        if detections is None and label_path is not None:
            # Read detections from a YOLO label file
            detections = self._read_yolo_labels(label_path, width, height)
        elif detections is None:
            raise ValueError("Either label_path or detections must be provided")
        # Draw bounding boxes and labels
        annotated_image = self._draw_bounding_boxes(image, detections, show_labels, show_confidences)
        if output_path:
            cv2.imwrite(output_path, annotated_image)
            print(f"Visualization saved: {output_path}")
        return annotated_image
    def _draw_bounding_boxes(self, image, detections, show_labels, show_confidences):
        """Draw bounding boxes and labels"""
        annotated_image = image.copy()
        for i, bbox in enumerate(detections['boxes']):
            x1, y1, x2, y2 = map(int, bbox)
            # Pick a color
            color = self.colors[i % len(self.colors)]
            # Draw the bounding box
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2)
            if show_labels:
                # Build the label text
                class_id = int(detections['class_ids'][i]) if i < len(detections['class_ids']) else 0
                class_name = self.class_names[class_id] if class_id < len(self.class_names) else f'class_{class_id}'
                label = class_name
                if show_confidences and i < len(detections['confidences']):
                    confidence = detections['confidences'][i]
                    label += f" {confidence:.2f}"
                # Draw the label background
                label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
                cv2.rectangle(annotated_image,
                              (x1, y1 - label_size[1] - 10),
                              (x1 + label_size[0], y1),
                              color, -1)
                # Draw the label text
                cv2.putText(annotated_image, label,
                            (x1, y1 - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        return annotated_image
    def _read_yolo_labels(self, label_path, img_width, img_height):
        """Read YOLO-format labels and convert them to the detections format"""
        boxes = []
        class_ids = []
        confidences = []
        if not os.path.exists(label_path):
            return self._create_empty_detections()
        with open(label_path, 'r') as f:
            for label_line in f:
                parts = label_line.strip().split()
                if len(parts) < 5:
                    continue
                class_id = int(parts[0])
                x_center, y_center, width, height = map(float, parts[1:5])
                # Convert to absolute coordinates
                x_center_abs = x_center * img_width
                y_center_abs = y_center * img_height
                width_abs = width * img_width
                height_abs = height * img_height
                x1 = max(0, x_center_abs - width_abs / 2)
                y1 = max(0, y_center_abs - height_abs / 2)
                x2 = min(img_width, x_center_abs + width_abs / 2)
                y2 = min(img_height, y_center_abs + height_abs / 2)
                boxes.append([x1, y1, x2, y2])
                class_ids.append(class_id)
                confidences.append(1.0)  # YOLO labels carry no confidence; default to 1.0
        return {
            'boxes': np.array(boxes) if boxes else np.empty((0, 4)),
            'confidences': np.array(confidences) if confidences else np.empty(0),
            'class_ids': np.array(class_ids) if class_ids else np.empty(0),
            'masks': np.empty(0)
        }

    def _create_empty_detections(self):
        """Create an empty detection result"""
        return {
            'boxes': np.empty((0, 4)),
            'confidences': np.empty(0),
            'class_ids': np.empty(0),
            'masks': np.empty(0)
        }
class AnnotationValidator:
    """
    Annotation validator: checks annotation quality and reports statistics
    """
    def __init__(self, class_names=None):
        self.class_names = class_names or ['object']

    def validate_annotations(self, image_path, label_path):
        """
        Validate the quality of an annotation file
        """
        # Read the image
        image = cv2.imread(image_path)
        if image is None:
            return {"error": f"Failed to read image: {image_path}"}
        height, width = image.shape[:2]
        # Read the annotation file
        if not os.path.exists(label_path):
            return {"error": f"Annotation file not found: {label_path}"}
        with open(label_path, 'r') as f:
            lines = f.readlines()
        if not lines:
            return {"warning": "Annotation file is empty"}
        issues = []
        class_counts = {}
        boxes = []
        for i, line in enumerate(lines):
            parts = line.strip().split()
            if len(parts) < 5:
                issues.append(f"Line {i+1}: malformed, expected at least 5 values, got {len(parts)}")
                continue
            try:
                class_id = int(parts[0])
                x_center = float(parts[1])
                y_center = float(parts[2])
                w = float(parts[3])
                h = float(parts[4])
                # Check value ranges
                if not (0 <= x_center <= 1):
                    issues.append(f"Line {i+1}: x_center out of range [0,1]: {x_center}")
                if not (0 <= y_center <= 1):
                    issues.append(f"Line {i+1}: y_center out of range [0,1]: {y_center}")
                if not (0 <= w <= 1):
                    issues.append(f"Line {i+1}: width out of range [0,1]: {w}")
                if not (0 <= h <= 1):
                    issues.append(f"Line {i+1}: height out of range [0,1]: {h}")
                # Check that the box is valid
                if w <= 0 or h <= 0:
                    issues.append(f"Line {i+1}: invalid box size: width={w}, height={h}")
                # Count classes
                class_counts[class_id] = class_counts.get(class_id, 0) + 1
                # Convert to pixel coordinates for the overlap check
                x1 = max(0, (x_center - w/2) * width)
                y1 = max(0, (y_center - h/2) * height)
                x2 = min(width, (x_center + w/2) * width)
                y2 = min(height, (y_center + h/2) * height)
                boxes.append((x1, y1, x2, y2))
            except ValueError as e:
                issues.append(f"Line {i+1}: value conversion error: {str(e)}")
        # Check for overlapping boxes
        overlapping_boxes = self._check_overlapping_boxes(boxes)
        if overlapping_boxes:
            issues.append(f"Found {len(overlapping_boxes)} pairs of overlapping boxes")
        # Build the report
        report = {
            "total_objects": len(lines),
            "class_distribution": class_counts,
            "issues": issues,
            "image_size": (width, height)
        }
        if class_counts:
            # Add a class-name mapping
            class_names_mapping = {}
            for class_id in class_counts:
                if class_id < len(self.class_names):
                    class_names_mapping[class_id] = self.class_names[class_id]
                else:
                    class_names_mapping[class_id] = f"unknown_class_{class_id}"
            report["class_names"] = class_names_mapping
        return report
    def _check_overlapping_boxes(self, boxes, overlap_threshold=0.5):
        """
        Check for overlapping bounding boxes
        """
        overlapping = []
        for i in range(len(boxes)):
            for j in range(i+1, len(boxes)):
                x1_a, y1_a, x2_a, y2_a = boxes[i]
                x1_b, y1_b, x2_b, y2_b = boxes[j]
                # Compute the intersection rectangle
                x_left = max(x1_a, x1_b)
                y_top = max(y1_a, y1_b)
                x_right = min(x2_a, x2_b)
                y_bottom = min(y2_a, y2_b)
                if x_right > x_left and y_bottom > y_top:
                    # Intersection area
                    intersection_area = (x_right - x_left) * (y_bottom - y_top)
                    # Areas of both boxes
                    area_a = (x2_a - x1_a) * (y2_a - y1_a)
                    area_b = (x2_b - x1_b) * (y2_b - y1_b)
                    # Overlap ratio relative to the smaller box
                    overlap = intersection_area / min(area_a, area_b)
                    if overlap > overlap_threshold:
                        overlapping.append((i, j, overlap))
        return overlapping
def main():
    """Example main function"""
    parser = argparse.ArgumentParser(description='FastSAM auto-labeling tool (no supervision dependency)')
    parser.add_argument('--input', type=str, default="images", help='input image directory or file path')
    parser.add_argument('--output', type=str, default="dataset", help='output directory')
    parser.add_argument('--model', type=str, default='FastSAM-x.pt', help='path to the FastSAM model')
    parser.add_argument('--classification-model', type=str, default='yolov8x-cls.pt', help='path to the classification model')
    parser.add_argument('--conf', type=float, default=0.4, help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.9, help='IoU threshold')
    parser.add_argument('--min-area', type=float, default=0.001, help='minimum area ratio')
    parser.add_argument('--max-area', type=float, default=0.95, help='maximum area ratio')
    parser.add_argument('--visualize', action='store_true', help='generate visualization results')
    parser.add_argument('--validate', action='store_true', help='validate the generated annotations')
    args = parser.parse_args()
    # Create the output directory
    os.makedirs(args.output, exist_ok=True)
    # Initialize the labeler
    labeler = FastSAMAutoLabeler(args.model, classification_model=args.classification_model)
    # Process the input
    if os.path.isfile(args.input):
        # Single-file mode
        result = labeler.process_image(
            args.input, args.output, args.conf, args.iou, args.min_area, args.max_area
        )
        print(f"Done: {result}")
        if args.visualize:
            # Get class names from the classification model
            class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
            visualizer = ManualAnnotationVisualizer(class_names=class_names)
            vis_path = os.path.join(args.output, 'visualization', f"{Path(args.input).stem}_annotated.jpg")
            os.makedirs(os.path.dirname(vis_path), exist_ok=True)
            visualizer.draw_annotations(
                result['image_path'], result['label_path'], output_path=vis_path
            )
        # Run validation if requested
        if args.validate:
            class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
            validator = AnnotationValidator(class_names=class_names)
            validation_report = validator.validate_annotations(result['image_path'], result['label_path'])
            print("\nAnnotation validation report:")
            print("=" * 50)
            if "error" in validation_report:
                print(f"Error: {validation_report['error']}")
            elif "warning" in validation_report:
                print(f"Warning: {validation_report['warning']}")
            else:
                print(f"Image size: {validation_report['image_size'][0]}x{validation_report['image_size'][1]}")
                print(f"Total objects: {validation_report['total_objects']}")
                print("\nClass distribution:")
                for class_id, count in validation_report['class_distribution'].items():
                    class_name = validation_report['class_names'].get(class_id, f"unknown_class_{class_id}")
                    print(f"  {class_name} ({class_id}): {count}")
                if validation_report['issues']:
                    print(f"\nIssues found ({len(validation_report['issues'])}):")
                    for issue in validation_report['issues']:
                        print(f"  - {issue}")
                else:
                    print("\nAnnotation quality looks good; no issues found.")
    else:
        # Directory mode
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
        image_files = []
        for ext in image_extensions:
            pattern = f'**/*{ext}'
            image_files.extend(Path(args.input).glob(pattern))
            image_files.extend(Path(args.input).glob(pattern.upper()))
        image_files = list(set(image_files))
        print(f"Found {len(image_files)} images")
        results = []
        successful = 0
        failed = 0
        for img_path in tqdm(image_files, desc="Processing images"):
            try:
                result = labeler.process_image(
                    str(img_path), args.output, args.conf, args.iou, args.min_area, args.max_area
                )
                results.append(result)
                successful += 1
                print(f"? Processed: {img_path.name} ({result['detections_count']} objects detected)")
            except Exception as e:
                failed += 1
                print(f"? Failed: {img_path.name} - error: {e}")
        print(f"\nProcessing complete!")
        print(f"Succeeded: {successful}, failed: {failed}")
        if args.visualize and results:
            # Generate visualizations for the first few images
            # Get class names from the classification model
            class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
            visualizer = ManualAnnotationVisualizer(class_names=class_names)
            vis_dir = os.path.join(args.output, 'visualization')
            os.makedirs(vis_dir, exist_ok=True)
            sample_count = min(5, len(results))
            print(f"\nGenerating visualizations for the first {sample_count} images...")
            for i, result in enumerate(results[:sample_count]):
                vis_path = os.path.join(vis_dir, f"{result['image_name']}_annotated.jpg")
                visualizer.draw_annotations(
                    result['image_path'], result['label_path'], output_path=vis_path
                )
        # Run validation if requested
        if args.validate and results:
            print("\nValidating annotation results...")
            class_names = labeler.classification_model.names if hasattr(labeler.classification_model, 'names') else None
            validator = AnnotationValidator(class_names=class_names)
            total_issues = 0
            for result in results:
                validation_report = validator.validate_annotations(result['image_path'], result['label_path'])
                if "issues" in validation_report and validation_report["issues"]:
                    total_issues += len(validation_report["issues"])
                    print(f"\nIssues found in {result['image_name']}:")
                    for issue in validation_report["issues"]:
                        print(f"  - {issue}")
            if total_issues == 0:
                print("All annotation files passed validation; no issues found.")
            else:
                print(f"\nFound {total_issues} issues in total.")

if __name__ == "__main__":
    main()
Summary and Outlook
The FastSAM auto-labeling tool presented here shows how modern computer vision models can be applied to real-world data annotation. Its main strengths are:
- Efficiency: combining FastSAM's fast segmentation with YOLO's accurate classification greatly speeds up annotation
- Flexibility: adjustable parameters and custom filtering rules adapt to different scenarios
- Quality assurance: built-in validation and visualization help ensure dataset quality
- Ease of use: a simple command-line interface with batch-processing support
Possible directions for future improvement include:
- Support for more annotation formats (e.g. COCO, Pascal VOC)
- An interactive correction interface
- Active-learning strategies that prioritize labeling high-uncertainty samples
- Faster model inference for real-time annotation
The tool is suitable not only for academic research but also for industrial projects, providing high-quality data for training computer vision models. With the walkthrough above, readers should understand the implementation well enough to customize it for their own needs.
We hope this guide helps you handle image annotation tasks more efficiently; feel free to explore and refine the tool in your own practice.