这是一份基于图片特征算法检测图片相似度的 Python 实践代码。
代码运行环境:
python2.7+OpenCV2.4.12
扩展包依赖:
numpy==1.11.0 PIL==1.1.7 cv2==2.4.12
# -*- coding: utf-8 -*-
import os
import time
import pycurl
from cStringIO import StringIO

import cv2
import numpy
from PIL import Image


class FeatureBasedImageMatchException(Exception):
    """Raised for unsupported detector/matcher names or unusable image input."""
    pass


class FeatureBasedImageMatch(object):
    """Image matching based on local feature descriptors.

    A detector computes key points and descriptors for an image; a matcher
    then pairs descriptors between images and the pairs are filtered with a
    nearest/second-nearest distance ratio test.
    """

    # bug: flann enums are missing from the cv2 bindings, so define them here
    FLANN_INDEX_KDTREE = 1
    FLANN_INDEX_LSH = 6

    def __init__(self, detector='SURF', matcher='FLANN'):
        """
        :param detector: string SURF/SIFT/ORB/BRISK
            Algorithm used to compute features. SURF is faster than SIFT
            but loses some accuracy.
        :param matcher: string FLANN/BF
            Algorithm used to match features. BF is an exhaustive
            (brute-force) search, FLANN is a nearest-neighbour search that
            is faster on large data sets.
        :raises FeatureBasedImageMatchException: for an unknown detector or
            matcher name.

        Description:
            Feature algorithm references
                1.http://docs.opencv.org/2.4/modules/nonfree/doc/feature_detection.html
                2.http://docs.opencv.org/2.4/modules/features2d/doc/feature_detection_and_description.html
            Matcher selection reference
                1.http://docs.opencv.org/2.4/modules/features2d/doc/common_interfaces_of_descriptor_matchers.html#BFMatcher%20:%20public%20DescriptorMatcher
        """
        # Choose the feature algorithm and the distance norm that goes with it.
        if detector == 'SURF':
            self.detector = cv2.SURF(500, nOctaves=4, nOctaveLayers=2, extended=0, upright=1)
            norm = cv2.NORM_L2
        elif detector == 'SIFT':
            self.detector = cv2.SIFT()
            norm = cv2.NORM_L2
        elif detector == 'ORB':
            self.detector = cv2.ORB(400)
            norm = cv2.NORM_HAMMING
        elif detector == 'BRISK':
            self.detector = cv2.BRISK()
            norm = cv2.NORM_HAMMING
        else:
            raise FeatureBasedImageMatchException('Detector %s not support yet.' % detector)

        # Choose the matching algorithm.
        if matcher == 'FLANN':
            if norm == cv2.NORM_L2:
                flann_params = dict(algorithm=self.FLANN_INDEX_KDTREE, trees=5)
            else:
                flann_params = dict(algorithm=self.FLANN_INDEX_LSH,
                                    table_number=6,  # 12
                                    key_size=12,  # 20
                                    multi_probe_level=1)  # 2
            # bug : need to pass empty dict (#1329)
            self.matcher = cv2.FlannBasedMatcher(flann_params, {})
        elif matcher == 'BF':
            self.matcher = cv2.BFMatcher(norm)
        else:
            raise FeatureBasedImageMatchException('Matcher %s not support yet.' % matcher)

        self.train_descriptors = []

    def detect_and_compute(self, image=None, image_file=None):
        """Detect key points and compute their descriptors.

        :param image: numpy.ndarray holding the image; takes precedence over
            ``image_file`` when given.
        :param image_file: local path or URL, loaded through
            ``ImageMatch.image_read`` when ``image`` is not an ndarray.
        :return: tuple ``(key_points, descriptors)`` from the detector.
        :raises FeatureBasedImageMatchException: when neither input is given
            or the image could not be loaded.
        """
        if not isinstance(image, numpy.ndarray) and not image_file:
            raise FeatureBasedImageMatchException('Param image or image_file must needed.')
        if not isinstance(image, numpy.ndarray):
            image = ImageMatch.image_read(image_file)
            # image_read signals a failed download with False; fail loudly
            # here instead of handing a non-image to OpenCV.
            if not isinstance(image, numpy.ndarray):
                raise FeatureBasedImageMatchException(
                    'Failed to load image from %s.' % image_file)
        key_points, descriptors = self.detector.detectAndCompute(image, None)
        return key_points, descriptors

    def many_image_match(self, descriptors, ratio=0.75):
        """Match descriptors against everything previously added to the matcher.

        :param descriptors: query descriptors.
        :param ratio: ratio-test threshold, see ``filter_matches``.
        :return: list of matches that pass the ratio test.
        """
        matches = self.matcher.knnMatch(descriptors, k=2)
        return self.filter_matches(matches, ratio)

    def image_match(self, descriptors_1, descriptors_2, ratio=0.75):
        """Match two descriptor sets directly against each other.

        :return: list of matches that pass the ratio test; an empty list when
            the matcher raises ``TypeError`` or ``cv2.error`` (the error is
            printed).
        """
        try:
            matches = self.matcher.knnMatch(descriptors_1, trainDescriptors=descriptors_2, k=2)
        except TypeError as e:
            print(e)
            return []
        except cv2.error as e:
            print(e)
            return []
        return self.filter_matches(matches, ratio)

    def image_match_explore(self, key_points_1, descriptors_1, key_points_2, descriptors_2,
                            ratio=0.75):
        """Match two single images and keep the key points of each match.

        :return: tuple ``(points_1, points_2, kp_pairs)`` as produced by
            ``ImageMatch.filter_matches``.
        """
        matches = self.matcher.knnMatch(descriptors_1, trainDescriptors=descriptors_2, k=2)
        points_1, points_2, matches = ImageMatch.filter_matches(
            key_points_1, key_points_2, matches, ratio)
        return points_1, points_2, matches

    def add_matches_image(self, descriptors):
        """Add one image's descriptors to the matcher.

        :return: True on success, False when the matcher rejects the
            descriptors with ``TypeError``.
        """
        try:
            self.matcher.add([descriptors])
        except TypeError:
            return False
        return True

    def clear_matches_image(self):
        """Remove all descriptors previously added to the matcher."""
        self.matcher.clear()

    @staticmethod
    def filter_matches(matches, ratio=0.75):
        """Keep the best match of every pair that passes the ratio test.

        A pair passes when its best distance is smaller than ``ratio`` times
        its second-best distance; entries without exactly two candidates are
        dropped.
        """
        return [m[0] for m in matches
                if len(m) == 2 and m[0].distance < m[1].distance * ratio]


class ImageMatch(object):
    """Utility helpers for image loading, match preview and feature storage."""

    def __init__(self):
        self.file_dir = '/tmp/'
        self.file_ext = '.npy'

    @staticmethod
    def _download(url, timeout=5):
        """Fetch ``url`` with pycurl; return the body, or None on a pycurl error."""
        curl = pycurl.Curl()
        buf = StringIO()
        try:
            curl.setopt(pycurl.URL, url)
            curl.setopt(pycurl.WRITEFUNCTION, buf.write)
            curl.setopt(pycurl.TIMEOUT, timeout)
            try:
                curl.perform()
            except pycurl.error as e:
                print(e)
                return None
            return buf.getvalue()
        finally:
            # Release the handle and the buffer on the error path as well.
            buf.close()
            curl.close()

    @staticmethod
    def image_read(image_file=None, image_buffer=None):
        """Load an image as a grayscale array.

        OpenCV can only read from a local file path, so PIL is used instead;
        the result differs very slightly from OpenCV's and that is ignored.

        :param image_file: local path; anything that is not an existing file
            is treated as a URL and downloaded.
        :param image_buffer: raw encoded image data; used instead of
            ``image_file`` when non-empty.
        :return: numpy array of the image converted to mode 'L', or False
            when the download fails or returns no data.
        """
        if image_buffer:
            im = Image.open(StringIO(image_buffer))
        elif os.path.isfile(image_file):
            im = Image.open(image_file)
        else:
            image_buffer = ImageMatch._download(image_file)
            if not image_buffer:
                return False
            im = Image.open(StringIO(image_buffer))
        return numpy.asarray(im.convert('L'))

    @staticmethod
    def filter_matches(kp1, kp2, matches, ratio=0.75):
        """Apply the ratio test and collect the key points of surviving matches.

        :param kp1: key points of the query image (indexed by ``queryIdx``).
        :param kp2: key points of the train image (indexed by ``trainIdx``).
        :param matches: knn match result with two candidates per entry.
        :return: tuple ``(p1, p2, kp_pairs)``: float32 arrays of the matched
            coordinates in each image, and the list of key point pairs.
        """
        mkp1, mkp2 = [], []
        for m in matches:
            if len(m) == 2 and m[0].distance < m[1].distance * ratio:
                best = m[0]
                mkp1.append(kp1[best.queryIdx])
                mkp2.append(kp2[best.trainIdx])
        p1 = numpy.float32([kp.pt for kp in mkp1])
        p2 = numpy.float32([kp.pt for kp in mkp2])
        return p1, p2, list(zip(mkp1, mkp2))

    @staticmethod
    def explore_match(win, img1, img2, kp_pairs, status=None, H=None):
        """Show the two images side by side with their matched features.

        Inlier pairs are drawn as green dots joined by a green line, the rest
        as red crosses. The call blocks in ``cv2.waitKey()`` until a key is
        pressed, then closes the window.

        :param win: window name.
        :param img1: first image; copied into a single-channel canvas, so it
            is expected to be grayscale.
        :param img2: second image, placed to the right of ``img1``.
        :param kp_pairs: list of ``(key_point_1, key_point_2)`` pairs.
        :param status: per-pair inlier flags; all pairs are inliers when None.
        :param H: optional homography; when given, the outline of ``img1``
            transformed by it is drawn over ``img2``.
        :return: the rendered BGR image.
        """
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        vis = numpy.zeros((max(h1, h2), w1 + w2), numpy.uint8)
        vis[:h1, :w1] = img1
        vis[:h2, w1:w1 + w2] = img2
        vis = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)

        if H is not None:
            corners = numpy.float32([[0, 0], [w1, 0], [w1, h1], [0, h1]])
            corners = numpy.int32(
                cv2.perspectiveTransform(corners.reshape(1, -1, 2), H).reshape(-1, 2) + (w1, 0))
            cv2.polylines(vis, [corners], True, (255, 255, 255))

        if status is None:
            status = numpy.ones(len(kp_pairs), numpy.bool_)

        # python 2 / python 3 change of zip unpacking
        p1, p2 = [], []
        for kpp in kp_pairs:
            p1.append(numpy.int32(kpp[0].pt))
            # Shift the second image's points right by the first image's width.
            p2.append(numpy.int32(numpy.array(kpp[1].pt) + [w1, 0]))

        green = (0, 255, 0)
        red = (0, 0, 255)
        for (x1, y1), (x2, y2), inlier in zip(p1, p2, status):
            if inlier:
                cv2.circle(vis, (x1, y1), 2, green, -1)
                cv2.circle(vis, (x2, y2), 2, green, -1)
            else:
                r = 2
                thickness = 3
                cv2.line(vis, (x1 - r, y1 - r), (x1 + r, y1 + r), red, thickness)
                cv2.line(vis, (x1 - r, y1 + r), (x1 + r, y1 - r), red, thickness)
                cv2.line(vis, (x2 - r, y2 - r), (x2 + r, y2 + r), red, thickness)
                cv2.line(vis, (x2 - r, y2 + r), (x2 + r, y2 - r), red, thickness)

        # Connect inlier pairs after all markers are drawn.
        for (x1, y1), (x2, y2), inlier in zip(p1, p2, status):
            if inlier:
                cv2.line(vis, (x1, y1), (x2, y2), green)

        cv2.imshow(win, vis)
        cv2.waitKey()
        cv2.destroyWindow(win)
        return vis

    def _feature_file(self, picture_md5):
        """Return the feature path for ``picture_md5``, ending in '.npy'."""
        filename = self.get_feature_path(picture_md5)
        if not filename.endswith('.npy'):
            filename += '.npy'
        return filename

    def feature_save(self, picture_md5, feature):
        """Save ``feature`` with numpy.save under the path for ``picture_md5``."""
        filename = self.get_feature_path(picture_md5)
        return numpy.save(filename, feature)

    def feature_load(self, picture_md5):
        """Load and return the stored feature for ``picture_md5``."""
        return numpy.load(self._feature_file(picture_md5))

    def get_feature_path(self, picture_md5):
        """Return ``file_dir + picture_md5 + file_ext``."""
        return self.file_dir + picture_md5 + self.file_ext

    def feature_exist(self, picture_md5):
        """Return True when the feature file for ``picture_md5`` exists."""
        return os.path.isfile(self._feature_file(picture_md5))

    def feature_remove(self, picture_md5):
        """Delete the feature file for ``picture_md5``."""
        return os.remove(self._feature_file(picture_md5))


class timer(object):
    """Accumulating stopwatch, usable as a context manager.

    ``elapsed`` sums the duration of every start/stop interval until
    ``reset`` is called.
    """

    def __init__(self, func=time.time):
        self.elapsed = 0.0
        self._func = func
        self._start = None

    def start(self):
        """Begin an interval; raises RuntimeError when already running."""
        if self._start is not None:
            raise RuntimeError('Already started')
        self._start = self._func()

    def stop(self):
        """End the interval and add it to ``elapsed``; raises RuntimeError when not running."""
        if self._start is None:
            raise RuntimeError('Not started')
        end = self._func()
        self.elapsed += end - self._start
        self._start = None

    def reset(self):
        """Set the accumulated time back to zero."""
        self.elapsed = 0.0

    @property
    def running(self):
        """True while an interval is in progress."""
        return self._start is not None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()


def main():
    """Demo: match one query image against the sample images and print timings."""
    match_object = FeatureBasedImageMatch()
    # Sample images
    images = [
        'http://image-qzone.mamaquan.mama.cn/upload/2016/05/03/thumb_w196_bf6022f770f3b5f1eecc_w300X400_w192X256.jpg',
    ]
    # Image to check
    match_image = 'http://image-qzone.mamaquan.mama.cn/upload/2016/05/03/thumb_w196_4ab517ac123cc9abffbe_w300X400_w192X256.jpg'
    match_image_object = ImageMatch.image_read(match_image)
    match_key_point, match_descriptors = match_object.detect_and_compute(image=match_image_object)
    image_object = dict()

    # Feature matching between two single images
    x = images[0]
    image_object[x] = dict()
    image_object[x]['image'] = ImageMatch.image_read(x)
    with timer() as t:
        image_object[x]['key_points'], image_object[x]['descriptors'] = \
            match_object.detect_and_compute(image=image_object[x]['image'])
    print('Finish detect and compute image feature:%d and cost%f'
          % (len(image_object[x]['key_points']), t.elapsed))
    with timer() as t:
        points_1, points_2, matches = match_object.image_match_explore(
            match_key_point, match_descriptors,
            image_object[x]['key_points'], image_object[x]['descriptors'])
    print('Finish image match,match image feature:%d and cost%f' % (len(matches), t.elapsed))
    ImageMatch.explore_match('feature_match', match_image_object, image_object[x]['image'],
                             matches)

    # Best feature match across multiple images
    for x in images:
        image_object[x] = dict()
        image_object[x]['image'] = ImageMatch.image_read(x)
        with timer() as t:
            image_object[x]['key_points'], image_object[x]['descriptors'] = \
                match_object.detect_and_compute(image=image_object[x]['image'])
        print('Finish detect and compute image feature:%d and cost%f'
              % (len(image_object[x]['key_points']), t.elapsed))
        with timer() as t:
            match_object.add_matches_image(image_object[x]['descriptors'])
        print('Finish add matches image and cost%f' % t.elapsed)
    with timer() as t:
        matches = match_object.many_image_match(match_descriptors)
    print('Finish image match,match image feature:%d and cost%f' % (len(matches), t.elapsed))


if __name__ == '__main__':
    main()
你也可以在 GitHub 上找到它。