Faiss

Faiss(Facebook AI Similarity Search)是一个用于高效相似度搜索的库,特别适用于大规模的高维向量数据。它主要用于近似最近邻搜索(ANN),广泛应用于推荐系统、图像和文本相似度搜索等领域。

下面是如何使用 Faiss 的基本教程,包括安装、索引构建、插入数据和执行搜索的步骤。

1. 安装 Faiss

你可以通过以下命令安装 Faiss:

a. 安装 CPU 版本
pip install faiss-cpu
b. 安装 GPU 版本(如果你想使用 GPU 加速)
pip install faiss-gpu

2. 创建一个 Faiss 索引

首先,我们需要创建一个 Faiss 索引。Faiss 支持多种索引类型,不同的索引类型适用于不同的应用场景。最常见的索引类型包括:

  • Flat: 精确的暴力搜索(适合小规模数据)。
  • IVF (Inverted File): 更高效的近似搜索,适合中到大规模数据。
  • HNSW (Hierarchical Navigable Small World): 高效的近似搜索,适合大规模数据。
2.1 创建一个 Flat 索引
import faiss
import numpy as np

# 创建一个随机的 10000 条 128 维向量
d = 128  # 向量维度
nb = 10000  # 向量数量
np.random.seed(1234)
xb = np.random.random((nb, d)).astype('float32')

# 创建一个 Faiss 索引
index = faiss.IndexFlatL2(d)  # 使用 L2 距离度量创建 Flat 索引

# 向索引中添加向量
index.add(xb)

print("Number of vectors in the index:", index.ntotal)
2.2 创建一个 IVF 索引
# 创建一个 IVF 索引(适用于更大规模数据集)
nlist = 100  # 索引分桶的数量
quantizer = faiss.IndexFlatL2(d)  # 使用 Flat 索引作为量化器
index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

# 训练索引(需要一个大部分数据的子集来训练索引)
index_ivf.train(xb)

# 向索引中添加数据
index_ivf.add(xb)

print("Number of vectors in IVF index:", index_ivf.ntotal)
2.3 创建一个 HNSW 索引
# 创建一个 HNSW 索引
index_hnsw = faiss.IndexHNSWFlat(d, 32)  # 32 是 HNSW 图的连接数

# 向索引中添加向量
index_hnsw.add(xb)

print("Number of vectors in HNSW index:", index_hnsw.ntotal)

3. 执行相似度搜索

假设你有一个查询向量,并且你想查找与之最相似的向量。

3.1 搜索最近邻向量
# 创建一个随机查询向量
xq = np.random.random((5, d)).astype('float32')  # 5 个查询向量

# 执行最近邻搜索,返回最相似的 5 个向量
k = 5  # 查找最相似的 5 个向量
D, I = index.search(xq, k)  # D 是距离,I 是索引

print("Distances of the nearest neighbors:\n", D)
print("Indices of the nearest neighbors:\n", I)

D 是距离矩阵,I 是索引矩阵,分别表示每个查询向量与最相似的 k 个向量之间的距离和对应的索引。

3.2 搜索时的优化(IVF 或 HNSW)

在使用 IVF 或 HNSW 时,你可能需要指定更多的搜索参数来优化检索效果,比如 nprobe(在 IVF 中控制搜索的分桶数量),或 efSearch(在 HNSW 中控制搜索效率)。

# 在 IVF 索引中执行搜索
index_ivf.nprobe = 10  # 搜索时考虑的桶的数量
D_ivf, I_ivf = index_ivf.search(xq, k)

print("Distances of the nearest neighbors (IVF):\n", D_ivf)
print("Indices of the nearest neighbors (IVF):\n", I_ivf)
# 在 HNSW 索引中执行搜索
index_hnsw.hnsw_efSearch = 32  # 增加搜索时的图搜索精度
D_hnsw, I_hnsw = index_hnsw.search(xq, k)

print("Distances of the nearest neighbors (HNSW):\n", D_hnsw)
print("Indices of the nearest neighbors (HNSW):\n", I_hnsw)

4. 保存和加载索引

你可以将索引保存到磁盘,并在之后加载它以节省时间。

4.1 保存索引到文件
faiss.write_index(index, "index_flat.index")
4.2 从文件加载索引
index_loaded = faiss.read_index("index_flat.index")
print("Loaded index has", index_loaded.ntotal, "vectors.")

5. 使用 GPU 加速(可选)

如果你有 GPU,Faiss 支持在 GPU 上加速向量搜索。首先,你需要安装 GPU 版本的 Faiss,然后将索引移动到 GPU:

import faiss
res = faiss.StandardGpuResources()  # GPU 资源管理
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)  # 将索引移动到 GPU

# 搜索时使用 GPU 加速
D_gpu, I_gpu = gpu_index.search(xq, k)

总结

  • 创建索引:使用不同的索引类型(IndexFlatL2, IndexIVFFlat, IndexHNSWFlat 等)来创建 Faiss 索引。
  • 插入数据:使用 .add() 方法将向量插入到索引中。
  • 执行搜索:使用 .search() 方法进行最近邻搜索。
  • GPU 加速:如果需要更高的性能,可以使用 Faiss 的 GPU 版本来加速搜索。

如果你需要更多关于特定索引类型或高级功能的帮助,随时告诉我!

代码运行结果

import faiss
import numpy as np

# 创建一个随机的 10000 条 128 维向量
d = 128  # 向量维度
nb = 10000  # 向量数量
np.random.seed(1234)
xb = np.random.random((nb, d)).astype('float32')

# 创建一个 Faiss 索引
index = faiss.IndexFlatL2(d)  # 使用 L2 距离度量创建 Flat 索引

# 向索引中添加向量
index.add(xb)

print("Number of vectors in the index:", index.ntotal)

Number of vectors in the index: 10000
# 创建一个 IVF 索引(适用于更大规模数据集)
nlist = 100  # 索引分桶的数量
quantizer = faiss.IndexFlatL2(d)  # 使用 Flat 索引作为量化器
index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

# 训练索引(需要一个大部分数据的子集来训练索引)
index_ivf.train(xb)

# 向索引中添加数据
index_ivf.add(xb)

print("Number of vectors in IVF index:", index_ivf.ntotal)

Number of vectors in IVF index: 10000
# 创建一个 HNSW 索引
index_hnsw = faiss.IndexHNSWFlat(d, 32)  # 32 是 HNSW 图的连接数

# 向索引中添加向量
index_hnsw.add(xb)

print("Number of vectors in HNSW index:", index_hnsw.ntotal)

Number of vectors in HNSW index: 10000
# 创建一个随机查询向量
xq = np.random.random((5, d)).astype('float32')  # 5 个查询向量

# 执行最近邻搜索,返回最相似的 5 个向量
k = 5  # 查找最相似的 5 个向量
D, I = index.search(xq, k)  # D 是距离,I 是索引

print("Distances of the nearest neighbors:\n", D)
print("Indices of the nearest neighbors:\n", I)

Distances of the nearest neighbors:
 [[14.120588 14.428287 14.519493 14.678541 14.977838]
 [14.769967 15.277808 15.82649  15.894089 16.032442]
 [14.996597 15.646382 15.67143  15.70339  16.03984 ]
 [13.788441 13.872476 14.54991  14.625786 14.823553]
 [13.655876 13.713434 13.981209 14.274239 14.780527]]
Indices of the nearest neighbors:
 [[7548 5874 4044 1573 2826]
 [8619 2762 2862  417 6717]
 [6059 5596 1617  106 7259]
 [7167 6353 8302 4062 1648]
 [5531 9262 2933 5512 7722]]
# 在 IVF 索引中执行搜索
index_ivf.nprobe = 10  # 搜索时考虑的桶的数量
D_ivf, I_ivf = index_ivf.search(xq, k)

print("Distances of the nearest neighbors (IVF):\n", D_ivf)
print("Indices of the nearest neighbors (IVF):\n", I_ivf)

Distances of the nearest neighbors (IVF):
 [[14.120588 14.519493 15.13227  15.136275 15.145167]
 [15.277808 15.82649  15.894089 16.032442 16.066723]
 [15.646382 16.70518  17.00241  17.005293 17.031368]
 [13.788441 14.54991  14.858726 15.078858 15.11237 ]
 [14.780527 15.026012 15.086645 15.142347 15.270004]]
Indices of the nearest neighbors (IVF):
 [[7548 4044 7775 4722 2860]
 [2762 2862  417 6717 9150]
 [5596  886 2020 6283 5516]
 [7167 8302 9207  656 5963]
 [7722 7668 8303 1648  557]]
# 在 HNSW 索引中执行搜索
index_hnsw.hnsw_efSearch = 32  # 增加搜索时的图搜索精度
D_hnsw, I_hnsw = index_hnsw.search(xq, k)

print("Distances of the nearest neighbors (HNSW):\n", D_hnsw)
print("Indices of the nearest neighbors (HNSW):\n", I_hnsw)

Distances of the nearest neighbors (HNSW):
 [[14.120588  14.9778385 15.132271  15.145167  15.160655 ]
 [15.277808  16.032898  16.158434  16.185116  16.32803  ]
 [14.996596  15.646384  15.67143   15.70339   16.03984  ]
 [13.788441  14.549911  14.858726  14.890352  15.112368 ]
 [13.655877  14.274239  14.780525  15.016905  15.026011 ]]
Indices of the nearest neighbors (HNSW):
 [[7548 2826 7775 2860 9951]
 [2762 6644 7958 5217 4221]
 [6059 5596 1617  106 7259]
 [7167 8302 9207 9815 5963]
 [5531 5512 7722 5746 7668]]
faiss.write_index(index, "index_flat.index")
index_loaded = faiss.read_index("index_flat.index")
print("Loaded index has", index_loaded.ntotal, "vectors.")
Loaded index has 10000 vectors.
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐