返回文章列表

chunker

2 min read
PYTHON
左右滑动查看完整代码
from pathlib import Path
from typing import List, Optional

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from loguru import logger

from config import Config, get_config
from embedder import EmbeddingModel
from chunk_saver import ChunkSaver


class SemanticChunker:
    def __init__(
        self,
        config: Optional[Config] = None,
        breakpoint_threshold: float = 0.5,
        chunks_dir: str = "./data/chunks",
    ):
        """
        初始化语义分块器。
        Args:
            config: 配置对象
            breakpoint_threshold: 语义断点阈值(0-1,越大越容易断开)
            chunks_dir: 分块数据存储目录
        """
        self.config = config or get_config()
        self.breakpoint_threshold = breakpoint_threshold
        self.min_chunk_size = 100
        self.max_chunk_size = 512
        self.chunks_dir = Path(chunks_dir)
        # 创建存储目录。
        self.chunks_dir.mkdir(parents=True, exist_ok=True)
        self.chunk_saver = ChunkSaver(
            chunks_dir=self.chunks_dir,
            markdown_prefix="semantic_chunk",
            markdown_title="Semantic Chunk",
        )
        self._sentence_splitter = RecursiveCharacterTextSplitter(
            chunk_size=100,
            chunk_overlap=0,  # 根据标点裁剪,所以这里不需要 overlap。
            separators=["\n\n", "\n", "。", ".", "!", "!", "?", "?", ";", ";"],
            is_separator_regex=False,
        )
        self._embedding = self._create_embedding()  # 把句子转成向量,用于计算相似度。
        logger.info(
            "语义分块器初始化:threshold={}, min={}, max={}",
            breakpoint_threshold,
            self.min_chunk_size,
            self.max_chunk_size,
        )

    def _create_embedding(self) -> Embeddings:
        """创建嵌入模型实例。"""
        return EmbeddingModel(self.config).embeddings


    # 这里有点看不明白,但是不重要,后面再补,只需要知道这个公式可以计算向量是否相似就行。
    def _compute_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """计算两个向量的余弦相似度。"""
        arr1 = np.array(vec1)
        arr2 = np.array(vec2)
        return float(np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2)))

    # 根据语义相似度合并文本。
    def _split_text(self, text: str) -> List[str]:
        # 根据标点符号切句。
        sentences = self._sentence_splitter.split_text(text)
        if not sentences:
            return []
        if len(sentences) == 1:
            return sentences
        # 转成向量,方便后面计算相似度。
        vector_list = self._embedding.embed_documents(sentences)
        # 新合并出的文本列表。
        chunks = []
        # 第一段先作为当前块的起点。
        current_chunk = sentences[0]

        for i in range(1, len(sentences)):
            next_sentence = sentences[i]
            similarity = self._compute_similarity(vector_list[i - 1], vector_list[i])
            # 合并条件:
            # 1. 合并后长度不能超过 max_chunk_size;
            # 2. 语义相似,或者当前块还没达到最小长度。
            merged_length = len(current_chunk) + 1 + len(next_sentence)
            should_merge = merged_length <= self.max_chunk_size and (
                similarity > self.breakpoint_threshold
                or len(current_chunk) < self.min_chunk_size
            )
            # 如果可以合并。
            if should_merge:
                current_chunk += " " + next_sentence
                # 这里不 append,因为后面可能还能继续合并。
                continue
            # 不能合并时,说明 current_chunk 已经成型,可以加入结果。
            chunks.append(current_chunk)
            # 下一块从当前句子开始。
            current_chunk = next_sentence

        # 最后还会剩下一个块。
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    def _split_document(self, documents: List[Document]) -> List[Document]:
        """把文档列表切成 Document chunk 列表。"""
        chunks = []
        logger.info("开始语义分块处理")
        for doc in documents:
            chunk_texts = self._split_text(doc.page_content)
            for chunk_index, chunk_text in enumerate(chunk_texts):
                # 复制 metadata,避免多个 chunk 共享同一个 dict。
                metadata = dict(doc.metadata)
                metadata["chunk_index"] = chunk_index
                metadata["chunk_size"] = len(chunk_text)
                chunks.append(Document(page_content=chunk_text, metadata=metadata))
        return chunks

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """对外入口:优先读缓存,否则执行语义切分并保存结果。"""
        if not documents:
            logger.warning("输入文档列表为空")
            return []

        chunks = self._split_document(documents)
        logger.info("语义切分完成:{} 个 chunk", len(chunks))
        if chunks:
            self.chunk_saver.save_chunks(chunks)
        return chunks

其实像md这种文档,根据标题去切割也可以(###),因为大多数都是分好了的