返回文章列表

将之前的项目修改为langgraph

2 min read

langgraph 是在langchain上的拓展,那么之前langchain的项目也可以通过langgraph编排来实现整个的流程 虽然是langgraph编排的,但是我们流程依然没变,和之前langchain的一样

查询改写 -> 检索 -> 重排 -> 父子 chunk 扩展 -> 生成答案

PYTHON
左右滑动查看完整代码
"""
LangGraph 版本的 RAG 流程。

这个文件的目标不是替换 multi_functional_chain.py,而是把现有 RAG 步骤拆成图节点,
方便学习 LangGraph 的状态流转、节点拆分和后续条件分支扩展。
"""

from typing import Any, NotRequired, TypedDict

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import END, START, StateGraph

from multi_functional_chain import MutiFunctionalRAGChain


class RAGState(TypedDict):
    """LangGraph 节点之间共享的状态。"""

    question: str # 用户的初始提问
    queries: NotRequired[list[str]] # 改写后的查询
    docs: NotRequired[list[Document]] # 召回的文档
    context: NotRequired[str] # 格式化后的文档
    answer: NotRequired[str] # 生成的答案
    sources: NotRequired[list[dict[str, Any]]] # 来源信息   


class RAGGraph:
    """把现有 MutiFunctionalRAGChain 包装成 LangGraph 图。"""

    def __init__(self, chain: MutiFunctionalRAGChain):
        self.chain = chain
        self.app = self._build_graph()

    def _build_graph(self):
        """构建图结构:查询改写 -> 检索 -> 重排 -> 父子 chunk 扩展 -> 生成答案。"""
        graph = StateGraph(RAGState)

        graph.add_node("rewrite_query", self._rewrite_query) # 查询改写节点
        graph.add_node("retrieve_docs", self._retrieve_docs) # 检索节点
        graph.add_node("rerank_docs", self._rerank_docs) # 重排节点
        graph.add_node("expand_parent_chunks", self._expand_parent_chunks) # 父子 chunk 扩展 节点
        graph.add_node("format_context", self._format_context) # 格式化文档
        graph.add_node("generate_answer", self._generate_answer) # 生成答案
        graph.add_node("build_sources", self._build_sources) # 构建来源信息

        graph.add_edge(START, "rewrite_query") # 起点(用户输入) -> 查询改写
        graph.add_edge("rewrite_query", "retrieve_docs") # 查询改写 -> 检索
        graph.add_edge("retrieve_docs", "rerank_docs") # 检索 -> 重排
        graph.add_edge("rerank_docs", "expand_parent_chunks") # 重排 -> 父子 chunk 扩展
        graph.add_edge("expand_parent_chunks", "format_context") # 父子 chunk 扩展 -> 格式化文档
        graph.add_edge("format_context", "generate_answer") # 格式化文档 -> 生成答案
        graph.add_edge("generate_answer", "build_sources") # 生成答案 -> 构建来源信息
        graph.add_edge("build_sources", END) # 构建来源信息 -> 终点(返回结果)

        return graph.compile() # 编译图结构

    def _rewrite_query(self, state: RAGState) -> dict[str, list[str]]:
        """把用户问题改写成多路查询,并保留原始问题。"""
        question = state["question"]
        rewritten_queries = self.chain.query_rewriter.generate_multi_queries(question)
        queries = [question] + rewritten_queries
        return {"queries": queries}

    def _retrieve_docs(self, state: RAGState) -> dict[str, list[Document]]:
        """使用现有混合检索器召回候选子 chunk。"""
        docs = self.chain._retrieve_docs(state["queries"])
        return {"docs": docs}

    def _rerank_docs(self, state: RAGState) -> dict[str, list[Document]]:
        """使用 reranker 对候选子 chunk 重排序,并只保留 top_k 个。"""
        reranked = self.chain._reranker.rerank(
            state["question"],
            state["docs"],
            top_k=self.chain.config.top_k,
        )
        docs = [doc for doc, _ in reranked]
        return {"docs": docs}

    def _expand_parent_chunks(self, state: RAGState) -> dict[str, list[Document]]:
        """根据命中的子 chunk,补充有限窗口内的父级上下文。"""
        docs = self.chain._expand_parent_chunks(state["docs"])
        return {"docs": docs}

    def _format_context(self, state: RAGState) -> dict[str, str]:
        """把最终文档列表格式化成 prompt 里的参考文档。"""
        context = self.chain._format_docs(state["docs"])
        return {"context": context}

    def _generate_answer(self, state: RAGState) -> dict[str, str]:
        """调用 LLM 基于上下文生成答案。"""
        prompt = self.chain.prompt_template.format(
            context=state["context"],
            question=state["question"],
        )
        response = self.chain.llm.invoke(prompt)
        answer = StrOutputParser().invoke(response)
        return {"answer": answer}

    def _build_sources(self, state: RAGState) -> dict[str, list[dict[str, Any]]]:
        """构建和 ask_with_source 类似的来源信息"""
        sources = [
            {
                "content": doc.page_content[:200] + "..."
                if len(doc.page_content) > 200
                else doc.page_content,
                "source": doc.metadata.get("file_name", "未知"),
                "metadata": doc.metadata,
            }
            for doc in state["docs"]
        ]
        return {"sources": sources}

    def ask(self, question: str) -> str:
        """只返回答案"""
        result = self.app.invoke({"question": question})
        return result["answer"]

    def ask_with_source(self, question: str) -> dict[str, Any]:
        """返回答案和来源"""
        result = self.app.invoke({"question": question})
        return {
            "answer": result["answer"],
            "sources": result["sources"],
            "question": question,
        }