返回文章列表

QueryRewriter 多路查询

1 min read
PYTHON
左右滑动查看完整代码
"""
查询重写模块

实现多种查询重写策略来提升检索召回率:
- 多路查询生成
- HyDE (假设文档嵌入)
- 查询扩展
"""

from typing import List, Optional
from loguru import logger

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config, get_config


class QueryRewriter:
    """
    查询重写器
    
    通过多种策略改写用户查询,提升检索效果。
    """
    
    # 多路查询生成提示词
    MULTI_QUERY_PROMPT = """你是一个专业的搜索查询优化专家。
请根据用户的原始问题,生成 3 个不同角度的搜索查询。
这些查询应该:
1. 保持原始问题的核心意图
2. 使用不同的表达方式
3. 可能包含同义词或相关概念

原始问题:{question}

请输出 3 个改写后的查询,每行一个,不要编号:"""

    # HyDE 提示词:生成假设答案
    HYDE_PROMPT = """请针对以下问题,写一段简短的假设性回答(约 50-100 字)。
这个回答应该像是从一份专业文档中摘录的内容。

问题:{question}

假设回答:"""
    def __init__(self, config: Config):

        kwargs = {
            "model": config.model_name,
            "api_key": config.openai_api_key,
            "temperature": 0.7,
        }
        if config.openai_base_url:
            kwargs["base_url"] = config.openai_base_url
        self.llm =  ChatOpenAI(**kwargs)
        logger.info("查询重写器初始化完成,使用模型:{}", config.model_name)
    def generate_multi_queries(self, question: str) -> List[str]:
        """使用多路查询生成策略改写查询"""
        prompt = ChatPromptTemplate.from_template(self.MULTI_QUERY_PROMPT)
        prompt_value = prompt.invoke({"question": question})
        response = self.llm.invoke(prompt_value)
        content = StrOutputParser().invoke(response)
        return [line.strip() for line in content.splitlines() if line.strip()]
    def generate_hyde_query(self, question: str) -> str:
        """使用 HyDE 策略生成假设答案作为查询的一部分"""
        prompt = ChatPromptTemplate.from_template(self.HYDE_PROMPT)
        prompt_value = prompt.invoke({"question": question})
        response = self.llm.invoke(prompt_value)
        return StrOutputParser().invoke(response).strip()