返回文章列表

稍微复杂点的chain

4 min read
PYTHON
左右滑动查看完整代码
"""
    构建rag链
    查询重写
    混合检索
    重排序
"""

from typing import List, Optional, Dict, Any
from loguru import logger
from pydantic import SecretStr
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

from config import Config, get_config
from vector import VectorStoreManager
from query_rewriter import QueryRewriter
from hybrid_retriever import HybridRetriever
from reranker import Reranker
# 默认的 RAG 提示模板,包含了对模型能力的描述和回答要求
DEFAULT_RAG_PROMPT = """
你是一个能够进行深度阅读理解、总结和分析的中文助手。

请基于参考文档回答用户问题,但允许你进行:
- 归纳总结
- 风格分析
- 观点选择
- 句子评价
- 合理推断(必须基于文档给出的内容)

禁止编造超出文档不存在的事实信息,但允许在文中信息基础上做解释性、分析性和判断性的扩展。

如果文档完全无关,才回答:"根据提供的文档,我无法回答这个问题"。

【参考文档】
{context}

【用户问题】
{question}

请用中文回答:
"""
reranker_model: str = "bge-reranker-base"
class MutiFunctionalRAGChain:
    """
    RAG 问答链
    
    组装检索器、Prompt 和 LLM,实现端到端的问答功能。
    """
    def __init__(
        self,
        documents: Optional[List[Document]] = None,
        config: Optional[Config] = None,
        vector_store_manager: Optional[VectorStoreManager] = None,
        llm: Optional[ChatOpenAI] = None,
        prompt_template: Optional[str] = None,
    ):
        """
        初始化 RAGChain 实例
        
        Args:
            config: 配置对象,如果为 None 则从环境变量加载
            vector_store_manager: 向量存储管理器实例,如果为 None 则创建默认实例
            llm: 语言模型实例,如果为 None 则创建默认 ChatOpenAI 实例
            prompt_template: RAG 提示模板字符串,如果为 None 则使用默认模板
        """
        self.config = config or get_config()
        self.vector_store_manager = vector_store_manager or VectorStoreManager(self.config)
        self.llm = self._create_llm()
        self.prompt_template =ChatPromptTemplate.from_template(prompt_template or DEFAULT_RAG_PROMPT)
        self.query_rewriter = QueryRewriter(self.config)
        self.hybrid_retriever = HybridRetriever(
            documents=documents,
            vectorstore_manager=self.vector_store_manager,
            config=self.config,
            bm25_weight=0.5,
            vector_weight=0.5,
        )
        self._reranker = Reranker(
            model_name=reranker_model,
            config=self.config,
        )
    def _create_llm(self) -> ChatOpenAI:
        """创建 ChatOpenAI 实例"""
        params={
            "model": self.config.model_name,
            "temperature": 0.2,
            "api_key": SecretStr(self.config.openai_api_key),
        }
        if self.config.openai_base_url:
            params["base_url"] = self.config.openai_base_url
        return ChatOpenAI(
            **params
        )
    def _format_docs(self, docs: List[Document]) -> str:
        """
        格式化检索到的文档
        
        Args:
            docs: 文档列表
            
        Returns:
            str: 格式化后的文档文本
        """
        formatted = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get("file_name", "未知来源")
            content = doc.page_content.strip()
            formatted.append(f"[文档 {i}] (来源: {source})\n{content}")
        # print("formatted__________",formatted)
        return "\n\n---\n\n".join(formatted)
    def _retrieve_docs(self, query: list[str]) -> List[Document]:
        """检索相关文档"""
        resdocs = []
        seen_doc_keys = set()  # 用于去重,记录已经添加过的文档键
        for q in query:
            # 直接进行混合搜索 检索器里面虽然也有去重逻辑,但是那是针对一个查询检索出来的文档
            # 当时多query时 外部所有的检索结果也需要去重
            # 增加每个查询的检索数量,确保覆盖更多内容
            docs = self.hybrid_retriever.search(q, k=10)
            for doc in docs:
                # 去重是因为有多query时可能会检索到重复的文档,hash(doc.page_content) 作为文档的唯一标识
                doc_key = hash(doc.page_content)
                if doc_key not in seen_doc_keys:
                    seen_doc_keys.add(doc_key)
                    resdocs.append(doc)
        return resdocs

    def _expand_parent_chunks(
        self,
        docs: List[Document],
        window_size: int = 4,
    ) -> List[Document]:
        """
        根据命中的子 chunk,补充同源文件的相邻 chunk 作为父级上下文。

        当前项目没有单独保存“父文档”对象,所以这里用“相邻 chunk”模拟父级上下文:
        - 子 chunk:检索和重排序真正命中的小片段;
        - 父级上下文:只取命中片段前后 window_size 个 chunk;
        - 作用:避免答案需要前后文拼接时,被 reranker 截断掉关键小片段。

        注意:这里不会因为“同源文件”相同就把整个文件的所有 chunk 都塞进上下文。
        即使一个文件有几百个 chunk,也只会围绕命中的 chunk 做有限窗口扩展。
        """
        # HybridRetriever 里保存的是参与 BM25 索引的全部 chunk。
        # 用它可以根据 source_file + chunk_index 找回命中 chunk 的邻居。
        all_docs = getattr(self.hybrid_retriever, "_documents", [])
        if not docs or not all_docs:
            return docs

        # 建立索引:
        # {
        #   "来源文件路径": {
        #       0: chunk0,
        #       1: chunk1,
        #   }
        # }
        # 这样后面拿到某个命中 chunk 时,可以快速找到同文件的前后 chunk。
        source_chunk_map: dict[str, dict[int, Document]] = {}
        for doc in all_docs:
            source = doc.metadata.get("source_file") or doc.metadata.get("file_name")
            chunk_index = doc.metadata.get("chunk_index")
            if source is None or chunk_index is None:
                continue
            source_chunk_map.setdefault(str(source), {})[int(chunk_index)] = doc

        expanded_docs = []
        seen_doc_keys = set()
        for doc in docs:
            source = doc.metadata.get("source_file") or doc.metadata.get("file_name")
            chunk_index = doc.metadata.get("chunk_index")

            # 如果某个文档没有 chunk_index,说明它不是标准切分出来的 chunk。
            # 这种情况没法找相邻 chunk,就保留它自身,避免误删。
            if source is None or chunk_index is None:
                doc_key = hash(doc.page_content)
                if doc_key not in seen_doc_keys:
                    seen_doc_keys.add(doc_key)
                    expanded_docs.append(doc)
                continue

            chunks = source_chunk_map.get(str(source), {})

            # 只取命中 chunk 附近的有限窗口,避免同源文件很大时把几百个 chunk 都塞进去。
            # window_size=4 表示命中 chunk_index=15 时,最多额外带上 11~19 这些邻居。
            start_index = int(chunk_index) - window_size
            end_index = int(chunk_index) + window_size
            candidate_indices = range(start_index, end_index + 1)

            for current_index in candidate_indices:
                parent_doc = chunks.get(current_index)
                if parent_doc is None:
                    continue

                # 同一个 chunk 可能被多个命中 chunk 的窗口覆盖,这里统一去重。
                doc_key = (str(source), current_index)
                if doc_key in seen_doc_keys:
                    continue
                seen_doc_keys.add(doc_key)
                expanded_docs.append(parent_doc)

        return expanded_docs
    
    def ask(self, question: str) -> str:
        """处理用户问题,返回答案"""
        logger.info(f"收到问题:{question}")
        querys = self.query_rewriter.generate_multi_queries(question)
        # 将原始问题也加入检索查询列表,提高召回率
        querys_with_original = [question] + querys
        logger.info(f"生成的多路查询(含原始问题):{querys_with_original}")
        
        # docs = self.vector_store_manager.similarity_search(question)
        docs = self._retrieve_docs(querys_with_original)
        # 先让 reranker 从粗检索结果里选出最相关的子 chunk。
        reranked = self._reranker.rerank(question, docs, top_k=self.config.top_k)
        docs = [doc for doc, _ in reranked]
        # 再根据命中的子 chunk 扩展前后文,避免关键上下文被 top_k 截掉。
        docs = self._expand_parent_chunks(docs)
        context = self._format_docs(docs)
        prompt = self.prompt_template.format(context=context, question=question)
        response = self.llm.invoke(prompt)
        answer = StrOutputParser().invoke(response)
        return answer
    
    def ask_with_source(self, question: str) -> Dict[str, Any]:
        """处理用户问题,返回答案和参考来源"""
        logger.info(f"收到问题:{question}")
        querys = self.query_rewriter.generate_multi_queries(question)
        # 将原始问题也加入检索查询列表,提高召回率
        querys_with_original = [question] + querys
        logger.info(f"生成的多路查询(含原始问题):{querys_with_original}")
        
        # 检索文档
        docs = self._retrieve_docs(querys_with_original)
        logger.info(f"检索到 {len(docs)} 个文档")
        
        # 先让 reranker 只保留最相关的子 chunk。
        reranked = self._reranker.rerank(question, docs, top_k=self.config.top_k)
        docs = [doc for doc, _ in reranked]
        logger.info(f"重排序后保留 {len(docs)} 个子 chunk")
        # 再补充每个命中子 chunk 的前后相邻 chunk,形成更完整的父级上下文。
        docs = self._expand_parent_chunks(docs)
        logger.info(f"父子 chunk 扩展后保留 {len(docs)} 个文档")
        
        # 格式化上下文
        context = self._format_docs(docs)
        prompt = self.prompt_template.format(context=context, question=question)
        
        # 生成答案
        response = self.llm.invoke(prompt)
        answer = StrOutputParser().invoke(response)
        
        # 构建来源信息
        sources = [
            {
                "content": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                "source": doc.metadata.get("file_name", "未知"),
                "metadata": doc.metadata
            }
            for doc in docs
        ]
        
        return {
            "answer": answer,
            "sources": sources,
            "question": question
        }