随着AI技术的发展,越来越多的应用场景涉及到海量数据的处理和检索。然而,仅靠传统的检索算法往往不能满足高质量的搜索结果需求。本文将介绍如何使用 RankGPT 来对检索结果进行重排序,以提升搜索质量。

RankGPT 的基本原理

RankGPT 是一种基于大模型(如 ChatGPT 或 GPT-4 等)的零样本列表式段落重排序技术。它通过生成排列和滑动窗口策略,来高效地对段落进行重排序。

示例:使用 RankGPT 对梵高维基百科页面进行重排序

我们以梵高的维基百科页面为例,展示如何使用 RankGPT 对检索结果进行重排序。

环境准备

首先,我们需要安装必要的依赖包:

%pip install llama-index-postprocessor-rankgpt-rerank
%pip install llama-index-llms-huggingface
%pip install llama-index-llms-openai
%pip install llama-index-llms-ollama

加载数据并构建索引

我们将梵高的维基百科页面下载下来,并构建向量存储索引:

import nest_asyncio
nest_asyncio.apply()

import logging
import sys
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.openai import OpenAI
import os
from pathlib import Path
import requests

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# 设置 OpenAI API 密钥
OPENAI_API_TOKEN = "your_openai_api_key"
os.environ["OPENAI_API_KEY"] = OPENAI_API_TOKEN

# 下载梵高维基百科页面
wiki_titles = ["Vincent van Gogh"]
data_path = Path("data_wiki")
for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={"action": "query", "format": "json", "titles": title, "prop": "extracts", "explaintext": True},
    ).json()
    page = next(iter(response["query"]["pages"].values()))
    wiki_text = page["extract"]

    if not data_path.exists():
        Path.mkdir(data_path)

    with open(data_path / f"{title}.txt", "w") as fp:
        fp.write(wiki_text)

# 加载文档并构建索引
documents = SimpleDirectoryReader("./data_wiki/").load_data()
index = VectorStoreIndex.from_documents(documents)

设置检索器和重排序器

我们将使用 OpenAI 的 GPT-3.5-turbo 模型来进行重排序,并展示没有重排序和使用 RankGPT 进行重排序的效果对比:

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import QueryBundle
from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank
import pandas as pd
from IPython.display import display, HTML

def get_retrieved_nodes(query_str, vector_top_k=10, reranker_top_n=3, with_reranker=False):
    query_bundle = QueryBundle(query_str)
    retriever = VectorIndexRetriever(index=index, similarity_top_k=vector_top_k)
    retrieved_nodes = retriever.retrieve(query_bundle)

    if with_reranker:
        reranker = RankGPTRerank(
            llm=OpenAI(model="gpt-3.5-turbo-16k", temperature=0.0, api_key=OPENAI_API_TOKEN),
            top_n=reranker_top_n,
            verbose=True,
        )
        retrieved_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)

    return retrieved_nodes

def pretty_print(df):
    return display(HTML(df.to_html().replace("\\n", "<br>")))

def visualize_retrieved_nodes(nodes):
    result_dicts = []
    for node in nodes:
        result_dict = {"Score": node.score, "Text": node.node.get_text()}
        result_dicts.append(result_dict)
    pretty_print(pd.DataFrame(result_dicts))

# 不使用重排序器进行检索
new_nodes = get_retrieved_nodes("Which date did Paul Gauguin arrive in Arles?", vector_top_k=3, with_reranker=False)
visualize_retrieved_nodes(new_nodes)

# 使用 RankGPT 进行重排序
new_nodes = get_retrieved_nodes("Which date did Paul Gauguin arrive in Arles?", vector_top_k=10, reranker_top_n=3, with_reranker=True)
visualize_retrieved_nodes(new_nodes)

可能遇到的错误

  1. API 密钥错误:确保已正确设置 OpenAI API 密钥。如果密钥无效或过期,会导致 API 调用失败。
  2. 网络问题:在调用外部 API 时,网络问题可能导致请求失败。建议检查网络连接并重试。
  3. 数据加载错误:确保维基百科页面已正确下载并保存。如果文件路径错误或文件损坏,会导致数据加载失败。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐