Introduction

I have a Milvus service deployed on my LAN, and I need to vectorize all the markdown documents in a folder into the Milvus database.
For splitting I use LangChain's built-in markdown header splitter, which attaches the headers to each chunk as metadata, keeping everything nicely organized.
However, vectorizing through the Milvus integration that LangChain provides ran into many errors that were hard to trace, so in the end I fell back to the native pymilvus client, which worked very well.
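
To make the splitting behaviour concrete, here is a minimal, self-contained sketch of what MarkdownHeaderTextSplitter produces (the sample markdown string is made up purely for illustration):

from langchain_text_splitters import MarkdownHeaderTextSplitter

# A made-up snippet just to show the output shape
sample_md = "# Guide\n\nIntro text.\n\n## Install\n\nRun the installer."

splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "Header_1"), ("##", "Header_2")]
)
for doc in splitter.split_text(sample_md):
    # Each chunk is a Document whose metadata records the enclosing headers
    print(doc.metadata, "->", doc.page_content)
# {'Header_1': 'Guide'} -> Intro text.
# {'Header_1': 'Guide', 'Header_2': 'Install'} -> Run the installer.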



Code

from config import envConfig, config
import os
from langchain_milvus import Milvus
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_ollama import OllamaEmbeddings
from langchain.retrievers import EnsembleRetriever
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection
import logging
from pymilvus import utility

logger = logging.getLogger('vectorize')

MAX_CHUNK_SIZE = 400   # reserved for secondary splitting of long chunks (see Notes)
CHUNK_OVERLAP = 100
embedding = OllamaEmbeddings(
    model=envConfig.OLLAMA_EMBEDDING,  # remember to point this at your embedding model
    base_url=envConfig.OLLAMA_URL
)


def get_all_md():  # adjust the folder path to your needs
    files = os.listdir("./markdown_files")
    return ['./markdown_files/' + f for f in files if f.endswith(".md")]


def split_by_md_header(mdDocs):
    """
    Split markdown text by headers; each chunk carries metadata such as
    {"Header_1": "xxx", "Header_2": "xxxx"} alongside its page_content.
    :param mdDocs: raw markdown text
    :return: list of Document chunks with header metadata
    """
    headers_to_split_on = [("#", "Header_1"), ("##", "Header_2"), ("###", "Header_3")]
    md_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, return_each_line=True)
    md_header_splits = md_splitter.split_text(mdDocs)
    for idx, doc in enumerate(md_header_splits):
        print(f"Index: {idx}, Metadata: {doc.metadata} Content: {doc.page_content}")
        print("---------------------")
    return md_header_splits


def get_vector_store(collectionName: str = 'LangChainCollection', is_local=True):
    if is_local:
        milvusConn = {"uri": envConfig.MILVUS_LOCAL_URI}
    else:
        milvusConn = {  # Zilliz Cloud-style connection settings
            "uri": envConfig.BACKUP_MILVUS_URI,
            "user": envConfig.BACKUP_MILVUS_USERNAME,
            "password": envConfig.BACKUP_MILVUS_PASSWORD,
            "secure": True,
        }
    logger.info(f"Milvus URI: {milvusConn}")

    # Connect to Milvus
    connections.connect(**milvusConn)

    # Drop the collection if it already exists
    if utility.has_collection(collectionName):
        utility.drop_collection(collectionName)

    # Define the collection fields
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=4096),
        FieldSchema(name="header_1", dtype=DataType.VARCHAR, max_length=256, default_value='EMPTY'),
        FieldSchema(name="header_2", dtype=DataType.VARCHAR, max_length=256, default_value='EMPTY'),
        FieldSchema(name="header_3", dtype=DataType.VARCHAR, max_length=256, default_value='EMPTY'),
        FieldSchema(name="metadata", dtype=DataType.JSON),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1024)  # bge-m3 embeddings are 1024-dimensional
    ]

    # Create the collection and an IVF_FLAT index on the vector field
    schema = CollectionSchema(fields, description="LangChain Collection")
    collection = Collection(name=collectionName, schema=schema)
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128}
    }
    collection.create_index(field_name="vector", index_params=index_params)

    # Load the collection into memory so it can be searched
    collection.load()

    return collection


def vectorize_all():
    print("Starting to vectorize all markdown files")
    embeddings = OllamaEmbeddings(
        model=envConfig.OLLAMA_EMBEDDING,
        base_url=envConfig.OLLAMA_URL
    )
    print(f"Embedding Model: {embeddings.model}")
    files = get_all_md()
    all_splits = []
    for file in files:
        with open(file, 'r', encoding='utf-8') as f:
            mdDocs = f.read()
        md_header_splits = split_by_md_header(mdDocs)
        all_splits.extend(md_header_splits)

    # Generate sequential ids to use as the primary key
    ids = list(range(len(all_splits)))

    # Build one row dict per chunk, embedding its text
    documents = []
    for doc_id, split in zip(ids, all_splits):
        metadata = split.metadata
        header_1 = metadata.get("Header_1", "")
        header_2 = metadata.get("Header_2", "")
        header_3 = metadata.get("Header_3", "")
        text_embedding = embeddings.embed_query(split.page_content)
        documents.append({
            "id": doc_id,
            "text": split.page_content,
            "header_1": header_1,
            "header_2": header_2,
            "header_3": header_3,
            "metadata": metadata,
            "vector": text_embedding
        })

    vector_store = get_vector_store(collectionName=config.APP_NAME.replace(' ', '_'), is_local=False)

    # Insert the rows and flush so they are persisted
    vector_store.insert(documents)
    vector_store.flush()
    print(f"Vectorized Done to collection: {config.APP_NAME.replace(' ', '_')}")


def test_vectorize():
    # milvusConn = {"uri": envConfig.MILVUS_LOCAL_URI}
    milvusConn = {
        "uri": envConfig.BACKUP_MILVUS_URI,
        "user": envConfig.BACKUP_MILVUS_USERNAME,
        "password": envConfig.BACKUP_MILVUS_PASSWORD,
        "secure": True,
    }
    embeddings = OllamaEmbeddings(
        model=envConfig.OLLAMA_EMBEDDING,
        base_url=envConfig.OLLAMA_URL
    )
    retriever = Milvus(collection_name=config.APP_NAME.replace(' ', '_'), connection_args=milvusConn,
                       embedding_function=embeddings).as_retriever(search_kwargs={"k": 6})
    # An ensemble of a single retriever; add more retrievers (e.g. BM25) to blend results
    ensemble_retriever = EnsembleRetriever(retrievers=[retriever])
    is_over = False
    while not is_over:
        query = input("Enter a query: ")
        if query == "exit":
            is_over = True
            continue
        result = ensemble_retriever.invoke(query)  # get_relevant_documents is deprecated
        for i, ele in enumerate(result):
            print(f"{i} - Found Doc: {ele.metadata}, {ele.page_content}")


if __name__ == "__main__":
    vectorize_all()
    test_vectorize()

Testing

Running the test_vectorize() function lets you query the collection repeatedly and inspect the retrieved documents; type exit to quit.
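
If you want to double-check results without going through LangChain, here is a minimal sketch of a direct pymilvus search against the same collection. It assumes the script above has already run, so the module-level embedding object and config are in scope, and search_direct is a name I made up:

def search_direct(query: str, top_k: int = 6):
    # Reuse the same local connection settings as get_vector_store
    connections.connect(uri=envConfig.MILVUS_LOCAL_URI)
    collection = Collection(config.APP_NAME.replace(' ', '_'))
    collection.load()
    query_vec = embedding.embed_query(query)  # module-level OllamaEmbeddings from above
    results = collection.search(
        data=[query_vec],
        anns_field="vector",
        param={"metric_type": "L2", "params": {"nprobe": 16}},
        limit=top_k,
        output_fields=["text", "header_1", "header_2", "header_3"],
    )
    for hit in results[0]:
        # Smaller L2 distance means a closer match
        print(hit.distance, hit.entity.get("header_1"), hit.entity.get("text"))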

Notes

The embedding model I use is bge-m3:latest. With other models, overly long chunks may cause errors; if that happens, you will need to write your own logic to split the text further, and remember to keep the metadata on each piece (see the sketch below).
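
A minimal sketch of such secondary splitting, assuming LangChain's RecursiveCharacterTextSplitter and the MAX_CHUNK_SIZE / CHUNK_OVERLAP constants defined at the top of the script (the function name split_long_chunks is my own):

from langchain_text_splitters import RecursiveCharacterTextSplitter

def split_long_chunks(md_header_splits, chunk_size=MAX_CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    """Break oversized header chunks into smaller pieces, keeping metadata."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # split_documents copies each Document's metadata onto every sub-chunk,
    # so the Header_1/2/3 fields survive the secondary split
    return text_splitter.split_documents(md_header_splits)

You would call this on the result of split_by_md_header() before building the row dicts, so every piece still carries its header metadata.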
