SAM模型实战:用Python把Segment Anything变成你的私人‘抠图神器’
用Python玩转SAM模型零基础打造智能抠图工具设计师朋友小李最近接了个电商产品图精修的单子客户发来的原始照片背景杂乱无章。传统钢笔工具抠图需要反复调整锚点一张图就要耗费半小时。直到他发现了Meta发布的Segment Anything模型——这个能通过简单点击自动识别物体的AI工具让他的工作效率提升了十倍。本文将带你从零开始用Python代码将SAM模型变成你的私人抠图助手。1. 环境配置与模型准备1.1 基础环境搭建首先需要准备Python 3.8环境推荐使用Anaconda创建独立环境避免依赖冲突conda create -n sam_env python3.9 conda activate sam_env安装核心依赖库时要注意版本兼容性pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip install opencv-python matplotlib numpy segment-anything提示CUDA版本需要与显卡驱动匹配NVIDIA用户可通过nvidia-smi查看支持的CUDA版本1.2 模型文件获取SAM提供三种预训练模型根据硬件条件选择模型类型参数量显存占用适用场景vit_h636M8GB高精度专业级vit_l308M4-6GB平衡型vit_b91M2-3GB快速测试下载模型权重文件后如sam_vit_b_01ec64.pth建议存放在项目根目录的models文件夹中。2. 基础抠图功能实现2.1 单点精准抠图以下代码实现点击图片某处自动抠出目标物体import cv2 import numpy as np from segment_anything import sam_model_registry, SamPredictor def init_sam(model_pathmodels/sam_vit_b_01ec64.pth): sam sam_model_registry[vit_b](checkpointmodel_path) sam.to(devicecuda) return SamPredictor(sam) def single_point_cutout(image_path, point_x, point_y): image cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) predictor init_sam() predictor.set_image(image) input_point np.array([[point_x, point_y]]) input_label np.array([1]) # 1表示前景点 masks, _, _ predictor.predict( point_coordsinput_point, point_labelsinput_label, multimask_outputTrue ) return masks[0] # 返回置信度最高的掩膜实际应用时可以通过OpenCV的鼠标回调实现交互式点选def on_mouse_click(event, x, y, flags, param): if event cv2.EVENT_LBUTTONDOWN: mask single_point_cutout(product.jpg, x, y) show_result(mask)2.2 框选批量抠图对于轮廓清晰的物体矩形框选效率更高def box_cutout(image_path, x1, y1, x2, y2): image cv2.imread(image_path) predictor init_sam() predictor.set_image(image) input_box np.array([x1, y1, x2, y2]) masks, _, _ predictor.predict( boxinput_box[None, :], multimask_outputFalse ) return masks[0]电商产品图中常见的多物体抠图场景# 同时抠取图片中的手机和耳机 boxes [ [120, 80, 300, 400], # 手机坐标 [350, 200, 450, 350] # 耳机坐标 ] combined_mask np.zeros_like(image) for box in boxes: mask box_cutout(electronics.jpg, *box) combined_mask np.logical_or(combined_mask, mask)3. 高级抠图技巧3.1 复杂场景精修当目标物体与背景颜色相近时可以结合正负点提示def refine_cutout(image_path, box, positive_points, negative_points): image cv2.imread(image_path) predictor init_sam() predictor.set_image(image) # 合并正负点 all_points np.array(positive_points negative_points) all_labels np.array([1]*len(positive_points) [0]*len(negative_points)) masks, _, _ predictor.predict( point_coordsall_points, point_labelsall_labels, boxnp.array(box), multimask_outputFalse ) return masks[0]示例抠取玻璃杯中的液体glass_box [200, 150, 400, 500] liquid_points [[300, 300], [320, 280]] # 液体区域点 frame_points [[280, 180]] # 杯框干扰点 mask refine_cutout(glass.jpg, glass_box, liquid_points, frame_points)3.2 批量自动化处理结合对象检测模型实现全自动流水线处理from detectron2 import model_zoo from detectron2.engine import DefaultPredictor def auto_cutout_pipeline(image_path): # 第一步检测物体 detector DefaultPredictor(model_zoo.get(COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml)) detections detector(cv2.imread(image_path)) # 第二步SAM抠图 sam_predictor init_sam() image cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) sam_predictor.set_image(image) results [] for box in detections[instances].pred_boxes: masks, _, _ sam_predictor.predict( boxbox.cpu().numpy()[None, :], multimask_outputFalse ) results.append(masks[0]) return results4. 实用化功能扩展4.1 透明背景生成将抠图结果保存为PNG透明图片def save_transparent(image_path, mask, output_path): image cv2.imread(image_path, cv2.IMREAD_UNCHANGED) if image.shape[2] 3: image cv2.cvtColor(image, cv2.COLOR_BGR2BGRA) image[:, :, 3] mask * 255 # 设置alpha通道 cv2.imwrite(output_path, image)4.2 背景替换合成实现电商常见的场景切换效果def change_background(orig_path, mask, bg_path, output_path): foreground cv2.imread(orig_path) background cv2.imread(bg_path) # 调整背景尺寸匹配前景 background cv2.resize(background, (foreground.shape[1], foreground.shape[0])) # 合成图像 composite np.where(mask[..., None], foreground, background) cv2.imwrite(output_path, composite)4.3 批量处理工具开发用PyQt构建可视化操作界面from PyQt5.QtWidgets import (QApplication, QMainWindow, QFileDialog, QGraphicsScene) class SAMEditor(QMainWindow): def __init__(self): super().__init__() self.init_ui() self.sam_predictor init_sam() def open_image(self): path, _ QFileDialog.getOpenFileName() self.image cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) self.display_image(self.image) def mousePressEvent(self, event): if hasattr(self, image): x, y event.pos().x(), event.pos().y() mask self.single_point_cutout(x, y) self.show_mask(mask)实际测试中发现对于毛绒玩具等边缘模糊的物体适当增加负样本点背景点能显著提升分割精度。而在处理金属反光物体时矩形框选比点选效果更稳定。