这是一份基于图片特征算法检测图片相似度的 Python 实践代码。

代码运行环境:

Python 2.7 + OpenCV 2.4.12

扩展包依赖:

numpy==1.11.0
PIL==1.1.7
cv2==2.4.12
pycurl(代码中用于下载远程图片)
# -*- coding: utf-8 -*-

import os
import time
import pycurl
from cStringIO import StringIO

import cv2
import numpy
from PIL import Image


class FeatureBasedImageMatch(object):
    """Feature-based image matching built on OpenCV detectors and matchers.

    An instance pairs one feature detector (SURF, SIFT, ORB or BRISK) with one
    descriptor matcher (FLANN or brute force).  Descriptors can be compared
    pairwise, or added to the matcher and queried afterwards.
    """

    # bug: flann enums are missing from cv2, so the values are repeated here
    FLANN_INDEX_KDTREE = 1
    FLANN_INDEX_LSH = 6

    def __init__(self, detector='SURF', matcher='FLANN'):
        """Create the feature detector and the descriptor matcher.

        :param detector: string SURF/SIFT/ORB/BRISK, algorithm used to compute
            features; SURF is faster than SIFT but loses some accuracy
        :param matcher: string FLANN/BF, algorithm used to match features; BF
            checks exhaustively, FLANN matches neighbours and is faster on
            large amounts of data
        :raises FeatureBasedImageMatchException: if detector or matcher is not
            one of the supported names

        Description:
        Feature algorithm references
        1.http://docs.opencv.org/2.4/modules/nonfree/doc/feature_detection.html
        2.http://docs.opencv.org/2.4/modules/features2d/doc/feature_detection_and_description.html
        Matching algorithm reference
        1.http://docs.opencv.org/2.4/modules/features2d/doc/common_interfaces_of_descriptor_matchers.html#BFMatcher%20:%20public%20DescriptorMatcher
        """
        # Pick the feature algorithm together with the distance norm that its
        # descriptors are compared with.
        if detector == 'SURF':
            self.detector = cv2.SURF(500, nOctaves=4, nOctaveLayers=2, extended=0, upright=1)
            norm = cv2.NORM_L2
        elif detector == 'SIFT':
            self.detector = cv2.SIFT()
            norm = cv2.NORM_L2
        elif detector == 'ORB':
            self.detector = cv2.ORB(400)
            norm = cv2.NORM_HAMMING
        elif detector == 'BRISK':
            self.detector = cv2.BRISK()
            norm = cv2.NORM_HAMMING
        else:
            raise FeatureBasedImageMatchException('Detector %s not support yet.' % detector)

        # Pick the matching algorithm; the FLANN index type follows the norm.
        if matcher == 'FLANN':
            if norm == cv2.NORM_L2:
                flann_params = dict(algorithm=self.FLANN_INDEX_KDTREE, trees=5)
            else:
                flann_params = dict(algorithm=self.FLANN_INDEX_LSH,
                                    table_number=6,  # 12
                                    key_size=12,  # 20
                                    multi_probe_level=1)  # 2
            self.matcher = cv2.FlannBasedMatcher(flann_params, {})  # bug : need to pass empty dict (#1329)
        elif matcher == 'BF':
            self.matcher = cv2.BFMatcher(norm)
        else:
            raise FeatureBasedImageMatchException('Matcher %s not support yet.' % matcher)

        self.train_descriptors = []

    def detect_and_compute(self, image=None, image_file=None):
        """Detect key points and compute their descriptors.

        :param image: numpy.ndarray image; used as-is when given
        :param image_file: location handed to ImageMatch.image_read when image
            is not an ndarray
        :return: tuple (key_points, descriptors) as produced by the detector
        :raises FeatureBasedImageMatchException: if neither argument is usable
            or the image could not be loaded
        """
        if not isinstance(image, numpy.ndarray):
            if not image_file:
                raise FeatureBasedImageMatchException('Param image or image_file must needed.')
            image = ImageMatch.image_read(image_file)
            # image_read reports a failed download by returning False; stop
            # here instead of handing a non-image to OpenCV.
            if not isinstance(image, numpy.ndarray):
                raise FeatureBasedImageMatchException('Failed to read image from %s.' % image_file)

        return self.detector.detectAndCompute(image, None)

    def many_image_match(self, descriptors, ratio=0.75):
        """Match descriptors against the descriptors stored in the matcher.

        :param descriptors: query descriptors
        :param ratio: ratio-test threshold passed to filter_matches
        :return: list of best matches that pass the ratio test
        """
        matches = self.matcher.knnMatch(descriptors, k=2)
        return self.filter_matches(matches, ratio)

    def image_match(self, descriptors_1, descriptors_2, ratio=0.75):
        """Match two descriptor sets against each other.

        A TypeError or cv2.error raised by the matcher is printed and an empty
        list is returned.

        :param descriptors_1: query descriptors
        :param descriptors_2: train descriptors
        :param ratio: ratio-test threshold passed to filter_matches
        :return: list of best matches that pass the ratio test
        """
        try:
            matches = self.matcher.knnMatch(descriptors_1, trainDescriptors=descriptors_2, k=2)
        except (TypeError, cv2.error) as e:
            print(e)
            return []

        return self.filter_matches(matches, ratio)

    def image_match_explore(self, key_points_1, descriptors_1, key_points_2, descriptors_2, ratio=0.75):
        """Match two individual images and keep the matched key points.

        :return: tuple (points_1, points_2, key point pairs) as returned by
            ImageMatch.filter_matches
        """
        matches = self.matcher.knnMatch(descriptors_1, trainDescriptors=descriptors_2, k=2)
        return ImageMatch.filter_matches(key_points_1, key_points_2, matches, ratio)

    def add_matches_image(self, descriptors):
        """Add the descriptors of one image to the matcher.

        :return: True on success, False if the matcher raised TypeError
        """
        try:
            self.matcher.add([descriptors])
        except TypeError:
            return False

        return True

    def clear_matches_image(self):
        """Remove every descriptor set stored in the matcher."""
        self.matcher.clear()

    @staticmethod
    def filter_matches(matches, ratio=0.75):
        """Keep the best match of every pair that passes the ratio test.

        :param matches: iterable of k=2 match groups; groups that do not hold
            exactly two matches are skipped
        :param ratio: the best distance must be below ratio times the
            second-best distance
        :return: list of the accepted best matches
        """
        return [m[0] for m in matches if len(m) == 2 and m[0].distance < m[1].distance * ratio]


class FeatureBasedImageMatchException(Exception):
    """Error raised by FeatureBasedImageMatch, e.g. for an unsupported detector or matcher name."""


class ImageMatch(object):
    """Helpers for loading images, previewing matches and storing features."""

    def __init__(self):
        # Directory prefix and extension used to build feature file names.
        self.file_dir = '/tmp/'
        self.file_ext = '.npy'

    @staticmethod
    def image_read(image_file=None, image_buffer=None, timeout=5):
        """Load an image as a grayscale numpy array.

        PIL is used instead of OpenCV because OpenCV can only read local file
        paths; the result differs very slightly, which is ignored.

        :param image_file: local path; anything that is not an existing file
            is treated as a URL and downloaded with pycurl
        :param image_buffer: encoded image data; takes precedence over
            image_file when non-empty
        :param timeout: pycurl transfer timeout in seconds for the download
        :return: numpy array of the image converted to mode 'L', or False when
            the download failed or returned no data
        """
        if image_buffer:
            source = StringIO(image_buffer)
        elif os.path.isfile(image_file):
            source = image_file
        else:
            image_buffer = ImageMatch._download(image_file, timeout)
            if not image_buffer:
                return False
            source = StringIO(image_buffer)

        return numpy.asarray(Image.open(source).convert('L'))

    @staticmethod
    def _download(url, timeout=5):
        """Fetch url with pycurl; return the body, or False on a transfer error."""
        buf = StringIO()
        c = pycurl.Curl()
        try:
            c.setopt(pycurl.URL, url)
            c.setopt(pycurl.WRITEFUNCTION, buf.write)
            c.setopt(pycurl.TIMEOUT, timeout)
            try:
                c.perform()
            except pycurl.error as e:
                print(e)
                return False
            return buf.getvalue()
        finally:
            # Release the handle and the buffer on the failure path as well.
            c.close()
            buf.close()

    @staticmethod
    def filter_matches(kp1, kp2, matches, ratio=0.75):
        """Apply the ratio test and collect the matched key points.

        :param kp1: key points indexed by the matches' queryIdx
        :param kp2: key points indexed by the matches' trainIdx
        :param matches: iterable of k=2 match groups; groups that do not hold
            exactly two matches are skipped
        :param ratio: the best distance must be below ratio times the
            second-best distance
        :return: tuple (p1, p2, kp_pairs): float32 arrays of the matched point
            coordinates and the list of (kp1 item, kp2 item) pairs
        """
        mkp1, mkp2 = [], []
        for m in matches:
            if len(m) == 2 and m[0].distance < m[1].distance * ratio:
                best = m[0]
                mkp1.append(kp1[best.queryIdx])
                mkp2.append(kp2[best.trainIdx])
        p1 = numpy.float32([kp.pt for kp in mkp1])
        p2 = numpy.float32([kp.pt for kp in mkp2])
        return p1, p2, list(zip(mkp1, mkp2))

    @staticmethod
    def explore_match(win, img1, img2, kp_pairs, status=None, H=None):
        """Show both images side by side with the matched features marked.

        Opens a window, waits for a key press and destroys the window again.

        :param win: window name
        :param img1: single-channel image drawn on the left
        :param img2: single-channel image drawn on the right
        :param kp_pairs: list of (key point in img1, key point in img2) pairs
        :param status: per-pair flags; truthy pairs are drawn green and
            connected, the others are marked with a red cross; defaults to
            all True
        :param H: optional perspective transform; when given, the outline of
            img1 transformed by H is drawn over img2
        :return: the BGR visualisation image
        """
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]
        vis = numpy.zeros((max(h1, h2), w1 + w2), numpy.uint8)
        vis[:h1, :w1] = img1
        vis[:h2, w1:w1 + w2] = img2
        vis = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)

        if H is not None:
            corners = numpy.float32([[0, 0], [w1, 0], [w1, h1], [0, h1]])
            # Shift by w1 because img2 starts at column w1 of the canvas.
            corners = numpy.int32(cv2.perspectiveTransform(corners.reshape(1, -1, 2), H).reshape(-1, 2) + (w1, 0))
            cv2.polylines(vis, [corners], True, (255, 255, 255))

        if status is None:
            status = numpy.ones(len(kp_pairs), numpy.bool_)
        p1 = [numpy.int32(pair[0].pt) for pair in kp_pairs]
        p2 = [numpy.int32(numpy.array(pair[1].pt) + [w1, 0]) for pair in kp_pairs]

        green = (0, 255, 0)
        red = (0, 0, 255)
        r = 2
        thickness = 3
        for (x1, y1), (x2, y2), inlier in zip(p1, p2, status):
            if inlier:
                cv2.circle(vis, (x1, y1), 2, green, -1)
                cv2.circle(vis, (x2, y2), 2, green, -1)
            else:
                for x, y in ((x1, y1), (x2, y2)):
                    cv2.line(vis, (x - r, y - r), (x + r, y + r), red, thickness)
                    cv2.line(vis, (x - r, y + r), (x + r, y - r), red, thickness)
        # Second pass so the connecting lines are drawn over all markers.
        for (x1, y1), (x2, y2), inlier in zip(p1, p2, status):
            if inlier:
                cv2.line(vis, (x1, y1), (x2, y2), green)

        cv2.imshow(win, vis)
        cv2.waitKey()
        cv2.destroyWindow(win)
        return vis

    def feature_save(self, picture_md5, feature):
        """Save a feature array to the file for picture_md5."""
        return numpy.save(self.get_feature_path(picture_md5), feature)

    def feature_load(self, picture_md5):
        """Load the feature array stored for picture_md5."""
        return numpy.load(self._npy_path(picture_md5))

    def get_feature_path(self, picture_md5):
        """Return file_dir + picture_md5 + file_ext."""
        return self.file_dir + picture_md5 + self.file_ext

    def feature_exist(self, picture_md5):
        """Return whether the feature file for picture_md5 exists."""
        return os.path.isfile(self._npy_path(picture_md5))

    def feature_remove(self, picture_md5):
        """Delete the feature file for picture_md5."""
        return os.remove(self._npy_path(picture_md5))

    def _npy_path(self, picture_md5):
        """Return the feature path with a guaranteed '.npy' suffix."""
        # numpy.save appends '.npy' when the name lacks it, so the file on
        # disk always carries that suffix even if file_ext is different.
        filename = self.get_feature_path(picture_md5)
        if not filename.endswith('.npy'):
            filename += '.npy'
        return filename


class timer:
    """Stopwatch that accumulates elapsed time over start/stop cycles.

    Can also be used as a context manager: entering starts the timer and
    leaving the block stops it.
    """

    def __init__(self, func=time.time):
        # func is the clock; it is called once by start() and once by stop().
        self._func = func
        self._start = None
        self.elapsed = 0.0

    @property
    def running(self):
        """True between a start() and the matching stop()."""
        return not (self._start is None)

    def start(self):
        """Record the current clock value; raise RuntimeError if already running."""
        if self._start is not None:
            raise RuntimeError('Already started')
        self._start = self._func()

    def stop(self):
        """Add the time since start() to elapsed; raise RuntimeError if not running."""
        began = self._start
        if began is None:
            raise RuntimeError('Not started')
        delta = self._func() - began
        self.elapsed += delta
        self._start = None

    def reset(self):
        """Set the accumulated time back to zero."""
        self.elapsed = 0.0

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()

if __name__ == '__main__':
    # Demo: compare one query image with the sample images, first pairwise
    # (with a preview window), then through the matcher's stored descriptors.
    matcher = FeatureBasedImageMatch()
    # Sample images
    sample_urls = [
        'http://image-qzone.mamaquan.mama.cn/upload/2016/05/03/thumb_w196_bf6022f770f3b5f1eecc_w300X400_w192X256.jpg',
    ]
    # Image to be checked
    query_url = 'http://image-qzone.mamaquan.mama.cn/upload/2016/05/03/thumb_w196_4ab517ac123cc9abffbe_w300X400_w192X256.jpg'
    query_image = ImageMatch.image_read(query_url)
    query_key_points, query_descriptors = matcher.detect_and_compute(image=query_image)
    samples = {}

    # Feature matching between two individual images
    first_url = sample_urls[0]
    samples[first_url] = {'image': ImageMatch.image_read(first_url)}
    with timer() as clock:
        sample_key_points, sample_descriptors = matcher.detect_and_compute(image=samples[first_url]['image'])
        samples[first_url]['key_points'] = sample_key_points
        samples[first_url]['descriptors'] = sample_descriptors
    print('Finish detect and compute image feature:%d and cost%f' % (
        len(samples[first_url]['key_points']), clock.elapsed))
    with timer() as clock:
        points_1, points_2, matched = matcher.image_match_explore(
            query_key_points, query_descriptors,
            samples[first_url]['key_points'], samples[first_url]['descriptors'])
    print('Finish image match,match image feature:%d and cost%f' % (len(matched), clock.elapsed))
    ImageMatch.explore_match('feature_match', query_image, samples[first_url]['image'], matched)

    # Best feature matches across several images: every sample is added to
    # the matcher, then the query is matched against all of them at once.
    for url in sample_urls:
        samples[url] = {'image': ImageMatch.image_read(url)}
        with timer() as clock:
            sample_key_points, sample_descriptors = matcher.detect_and_compute(image=samples[url]['image'])
            samples[url]['key_points'] = sample_key_points
            samples[url]['descriptors'] = sample_descriptors
        print('Finish detect and compute image feature:%d and cost%f' % (
            len(samples[url]['key_points']), clock.elapsed))
        with timer() as clock:
            matcher.add_matches_image(samples[url]['descriptors'])
        print('Finish add matches image and cost%f' % clock.elapsed)

    with timer() as clock:
        matched = matcher.many_image_match(query_descriptors)
    print('Finish image match,match image feature:%d and cost%f' % (len(matched), clock.elapsed))

你也可以在 GitHub 上找到它。