返回文章列表

vector

2 min read
PYTHON
左右滑动查看完整代码
from typing import List, Optional
from langchain_chroma import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_core.documents import Document
from loguru import logger

from config import Config
from embedder import EmbeddingModel


class VectorStoreManager:
    """Vector database manager wrapping a LangChain ``Chroma`` store.

    The ``Chroma`` instance is created lazily on first access to
    :attr:`vectorstore`. It uses ``config.chroma_persist_dir`` as its persist
    directory and the embeddings exposed by ``EmbeddingModel(config)``.
    """

    # Result count used when neither the caller nor ``config.top_k`` supplies
    # a truthy value (this was the original signature default).
    _DEFAULT_TOP_K = 5

    def __init__(self, config: Config, collection_name: str = "rag_documents"):
        """Store configuration; the Chroma store itself is not opened here.

        Args:
            config: Project configuration. ``chroma_persist_dir`` and ``top_k``
                are read from it.
            collection_name: Name of the Chroma collection to read/write.
        """
        self.collection_name = collection_name
        self.config = config
        self.persist_directory = self.config.chroma_persist_dir
        self._embedding_model = EmbeddingModel(self.config)
        # Populated on first access to the ``vectorstore`` property.
        self._vectorstore: Optional[Chroma] = None
        logger.info(
            f"向量存储管理器初始化: persist_directory={self.persist_directory}, "
            f"collection={self.collection_name}"
        )

    @property
    def vectorstore(self) -> Chroma:
        """Return the Chroma store, creating it on first access."""
        if self._vectorstore is None:
            self._vectorstore = self._create_vectorstore()
        return self._vectorstore

    def _create_vectorstore(self) -> Chroma:
        """Construct the Chroma store for this manager's collection and directory."""
        chroma_db = Chroma(
            collection_name=self.collection_name,
            embedding_function=self._embedding_model.embeddings,
            persist_directory=self.persist_directory,
        )

        logger.info(f"向量存储 {self.collection_name},已创建新实例")
        return chroma_db

    def _resolve_top_k(self, top_k: Optional[int]) -> int:
        """Pick the result count: explicit argument, then ``config.top_k``, then 5."""
        if top_k is not None:
            return top_k
        return self.config.top_k or self._DEFAULT_TOP_K

    def _is_empty(self) -> bool:
        """Return True when the underlying collection holds no entries."""
        # Go through the lazy ``vectorstore`` property rather than checking
        # ``_vectorstore is None``. Otherwise a collection persisted by an
        # earlier run would be reported as empty until something else touched
        # the store.
        return self.vectorstore._collection.count() == 0

    def add_documents(self, documents: List[Document]) -> None:
        """Write documents into the vector store.

        Args:
            documents: Documents to embed and store. An empty list is logged
                and ignored.
        """
        if not documents:
            logger.warning("没有文档可写入向量存储")
            return

        # Strip metadata values Chroma cannot store (non-primitive types) so
        # the write does not fail. This cleans each document's metadata; it
        # does not drop documents.
        valid_docs = filter_complex_metadata(documents)
        if not valid_docs:
            logger.warning("没有有效文档可写入向量存储")
            return

        self.vectorstore.add_documents(valid_docs)
        logger.info(f"已写入 {len(valid_docs)} 条文档到向量存储 {self.collection_name}")

    def similarity_search(
        self, query: str, top_k: Optional[int] = None
    ) -> List[Document]:
        """Return the documents most similar to ``query``.

        Args:
            query: Free-text query to embed and search with.
            top_k: Number of results. When None, falls back to
                ``config.top_k`` and then to 5.

        Returns:
            Matching documents, or an empty list if the collection is empty.
        """
        if self._is_empty():
            logger.warning("向量存储为空,无法执行相似度搜索")
            return []
        results = self.vectorstore.similarity_search(
            query, k=self._resolve_top_k(top_k)
        )
        logger.info(f"相似度搜索完成: query='{query}', results={len(results)}")
        return results

    def similarity_search_with_score(
        self, query: str, top_k: Optional[int] = None
    ) -> List[tuple[Document, float]]:
        """Like :meth:`similarity_search`, but also return each hit's score.

        Args:
            query: Free-text query to embed and search with.
            top_k: Number of results. When None, falls back to
                ``config.top_k`` and then to 5.

        Returns:
            ``(document, score)`` pairs as produced by Chroma, or an empty
            list if the collection is empty.
        """
        if self._is_empty():
            logger.warning("向量存储为空,无法执行相似度搜索")
            return []
        results = self.vectorstore.similarity_search_with_score(
            query, k=self._resolve_top_k(top_k)
        )
        logger.info(f"相似度搜索完成: query='{query}', results={len(results)}")
        return results

    def as_retriever(self, **kwargs):
        """Return a LangChain retriever backed by this store.

        ``search_kwargs["k"]`` defaults to ``config.top_k`` (or 5 when that is
        unset). Any other keyword arguments are passed to
        ``Chroma.as_retriever`` unchanged.
        """
        # Copy so the caller's dict is never mutated when ``k`` is injected,
        # and tolerate an explicit ``search_kwargs=None``.
        search_kwargs = dict(kwargs.pop("search_kwargs", None) or {})
        search_kwargs.setdefault("k", self._resolve_top_k(None))

        return self.vectorstore.as_retriever(search_kwargs=search_kwargs, **kwargs)