使用 Milvus 搜索图像是一种高效且准确的方法,它基于向量数据库技术,将图像转化为向量表示,然后在向量空间中进行相似度搜索。以下是使用 Milvus 搜索图像的一般步骤:

一、准备工作

  1. 安装 Milvus‌:

    • 可以从 Milvus 的官方网站下载并按照说明进行安装。
    • 确保系统满足 Milvus 的最低要求。
  2. 安装依赖库‌:

    • 通常需要安装 Python 3 和 pip,以及用于图像处理的库,如 OpenCV、Pillow 等。
    • 如果使用深度学习模型进行特征提取,还需要安装相应的深度学习框架,如 PyTorch 或 TensorFlow。
  3. 准备数据集‌:

    • 可以使用现有的图像数据集,如 PASCAL VOC、COCO 等。
    • 也可以自己制作数据集,确保图像质量和标注的准确性。

二、图像特征提取

  1. 选择特征提取模型‌:

    • 常用的深度学习模型包括 ResNet、VGG、CLIP 等。
    • 这些模型可以将图像转化为固定长度的向量表示,捕捉图像的主要特征。
  2. 进行特征提取‌:

    • 使用选定的模型对图像数据集进行特征提取。
    • 提取到的向量将用于后续的搜索和比较。

三、向量存储与索引

  1. 存储向量‌:

    • 将提取到的图像特征向量存储到 Milvus 数据库中。
    • Milvus 会自动为每个向量生成一个唯一的 ID,并与原始图像数据建立关联。
  2. 创建索引‌:

    • 为了提高搜索效率,可以为向量数据创建索引。
    • Milvus 支持多种索引类型,如 Flat Index、Annoy Index、Faiss Index 等。
    • 根据实际需求选择合适的索引类型,并设置相应的参数。

四、图像搜索

  1. 上传查询图像‌:

    • 当用户上传一张新的图像进行查询时,系统需要将其转化为向量表示。
    • 这通常通过与之前相同的特征提取模型实现。
  2. 执行搜索‌:

    • 使用 Milvus 的 API 或 Python SDK 执行相似度搜索。
    • 设置搜索的阈值、返回结果的数量等参数。
    • Milvus 会在其存储的向量集合中查找与查询向量最相似的向量,并返回相应的 ID。
  3. 展示结果‌:

    • 根据返回的向量 ID,在原始图像数据集中查找对应的图像。
    • 将搜索结果展示给用户,通常按照相似度排序。

五、优化与调整

  1. 优化特征提取模型‌:

    • 尝试不同的深度学习模型或调整模型参数,以提高特征提取的准确性和效率。
  2. 调整索引参数‌:

    • 根据实际需求调整索引的类型和参数,以平衡搜索速度和召回率。
  3. 数据质量与标注‌:

    • 确保图像数据的质量和标注的准确性,以提高搜索结果的可靠性。

数据集准备

首先,我们需要加载数据集并解压缩,以便进一步处理。

$ wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
$ unzip -q -o reverse_image_search.zip

前提条件

要运行本笔记本,您需要安装以下依赖项:

  • pymilvus>=2.4.2
  • timm
  • 火炬
  • numpy
  • sklearn
  • 枕头

要运行 Colab,我们提供了安装必要依赖项的便捷命令。

$ pip install pymilvus --upgrade
$ pip install timm

如果您使用的是 Google Colab,要启用刚刚安装的依赖项,可能需要重启运行时。(点击屏幕上方的 "Runtime(运行时)"菜单,从下拉菜单中选择 "Restart session(重新启动会话)")。

定义特征提取器

然后,我们需要定义一个特征提取器,使用 timm 的 ResNet-34 模型从图像中提取嵌入信息。

import torch
from PIL import Image
import timm
from sklearn.preprocessing import normalize
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class FeatureExtractor:
    def __init__(self, modelname):
        # Load the pre-trained model
        self.model = timm.create_model(
            modelname, pretrained=True, num_classes=0, global_pool="avg"
        )
        self.model.eval()

        # Get the input size required by the model
        self.input_size = self.model.default_cfg["input_size"]

        config = resolve_data_config({}, model=modelname)
        # Get the preprocessing function provided by TIMM for the model
        self.preprocess = create_transform(**config)

    def __call__(self, imagepath):
        # Preprocess the input image
        input_image = Image.open(imagepath).convert("RGB")  # Convert to RGB if needed
        input_image = self.preprocess(input_image)

        # Convert the image to a PyTorch tensor and add a batch dimension
        input_tensor = input_image.unsqueeze(0)

        # Perform inference
        with torch.no_grad():
            output = self.model(input_tensor)

        # Extract the feature vector
        feature_vector = output.squeeze().numpy()

        return normalize(feature_vector.reshape(1, -1), norm="l2").flatten()

创建 Milvus Collections

然后,我们需要创建一个 Milvus Collections 来存储图像嵌入信息

from pymilvus import MilvusClient

# Set up a Milvus client
client = MilvusClient(uri="example.db")
# Create a collection in quick setup mode
if client.has_collection(collection_name="image_embeddings"):
    client.drop_collection(collection_name="image_embeddings")
client.create_collection(
    collection_name="image_embeddings",
    vector_field_name="vector",
    dimension=512,
    auto_id=True,
    enable_dynamic_field=True,
    metric_type="COSINE",
)

至于MilvusClient 的参数:

  • uri 设置为本地文件,如./milvus.db ,是最方便的方法,因为它会自动利用Milvus Lite将所有数据存储在此文件中。
  • 如果数据规模较大,可以在docker 或 kubernetes 上设置性能更强的 Milvus 服务器。在此设置中,请使用服务器 uri,例如http://localhost:19530 ,作为您的uri 。
  • 如果你想使用Zilliz Cloud(Milvus 的全托管云服务),请调整uri 和token ,它们与 Zilliz Cloud 中的公共端点和 Api 密钥相对应。

将嵌入数据插入 Milvus

我们将使用 ResNet34 模型提取每张图片的嵌入,并将训练集中的图片插入 Milvus。

import os

extractor = FeatureExtractor("resnet34")

root = "./train"
insert = True
if insert is True:
    for dirpath, foldername, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".JPEG"):
                filepath = dirpath + "/" + filename
                image_embedding = extractor(filepath)
                client.insert(
                    "image_embeddings",
                    {"vector": image_embedding, "filename": filepath},
                )
from IPython.display import display

query_image = "./test/Afghan_hound/n02088094_4261.JPEG"

results = client.search(
    "image_embeddings",
    data=[extractor(query_image)],
    output_fields=["filename"],
    search_params={"metric_type": "COSINE"},
)
images = []
for result in results:
    for hit in result[:10]:
        filename = hit["entity"]["filename"]
        img = Image.open(filename)
        img = img.resize((150, 150))
        images.append(img)

width = 150 * 5
height = 150 * 2
concatenated_image = Image.new("RGB", (width, height))

for idx, img in enumerate(images):
    x = idx % 5
    y = idx // 5
    concatenated_image.paste(img, (x * 150, y * 150))
display("query")
display(Image.open(query_image).resize((150, 150)))
display("results")
display(concatenated_image)
'query'

png

png

'results'

Results

结果

我们可以看到,大部分图片都与搜索图片属于同一类别,即阿富汗猎犬。这说明我们找到了与搜索图片相似的图片。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐