Introduction

Retrieval-Augmented Generation (RAG) is a key technique for improving the accuracy and freshness of LLM output. This article walks through building a production-grade RAG Agent system, covering document processing, vector storage, retrieval optimization, and generation strategies.

1. RAG Architecture Design

flowchart TB
    subgraph "RAG System Architecture"
        U[User Query] --> QP[Query Processing]
        
        subgraph "Retrieval Module"
            QP --> VR[Vector Search]
            QP --> KR[Keyword Search]
            VR --> HM[Hybrid Matching]
            KR --> HM
        end
        
        subgraph "Knowledge Base"
            DS[(Document Store)]
            VS[(Vector Database)]
            IS[(Inverted Index)]
        end
        
        HM --> RR[Reranking]
        RR --> CA[Context Aggregation]
        
        subgraph "Generation Module"
            CA --> PE[Prompt Engineering]
            PE --> LLM[Large Language Model]
            LLM --> PP[Post-processing]
        end
        
        PP --> R[Response Output]
        
        DS -.-> VR
        VS -.-> VR
        IS -.-> KR
    end
    
    style U fill:#e8f5e9,stroke:#4caf50,stroke-width:2px
    style R fill:#fff3e0,stroke:#ff9800,stroke-width:2px
    style LLM fill:#e3f2fd,stroke:#2196f3,stroke-width:3px

1.1 System Architecture

from typing import List, Dict, Any, Optional
import numpy as np
from dataclasses import dataclass

@dataclass
class RAGConfig:
    chunk_size: int = 1000
    chunk_overlap: int = 200
    embedding_model: str = "text-embedding-ada-002"
    vector_store: str = "chromadb"
    retrieval_top_k: int = 5
    rerank_top_k: int = 3
    temperature: float = 0.7

class RAGAgent:
    def __init__(self, config: RAGConfig):
        self.config = config
        self.document_processor = DocumentProcessor(config)
        self.vector_store = self._init_vector_store()
        self.retriever = HybridRetriever(self.vector_store)
        self.reranker = CrossEncoderReranker()
        self.generator = AugmentedGenerator()
        
    def _init_vector_store(self):
        """Initialize the vector store backend configured in RAGConfig."""
        if self.config.vector_store == "chromadb":
            from chromadb import Client
            return ChromaVectorStore(Client())
        elif self.config.vector_store == "pinecone":
            import pinecone
            return PineconeVectorStore(pinecone)
        elif self.config.vector_store == "weaviate":
            import weaviate
            return WeaviateVectorStore(weaviate.Client())
        else:
            raise ValueError(f"Unsupported vector store: {self.config.vector_store}")
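
Before diving into each component, here is a minimal sketch of the query path these pieces add up to. It assumes the wrapper classes above are implemented and an LLM client is wired into the generator; the helper name is illustrative:

import asyncio

# Hedged end-to-end sketch: retrieve, rerank, then generate from the survivors
async def answer_question(agent: RAGAgent, question: str) -> str:
    candidates = await agent.retriever.retrieve(question, top_k=agent.config.retrieval_top_k)
    top_docs = agent.reranker.rerank(question, candidates, top_k=agent.config.rerank_top_k)
    context = "\n\n".join(doc["content"] for doc in top_docs)
    return await agent.generator.generate(question, context)

# asyncio.run(answer_question(RAGAgent(RAGConfig()), "What is hybrid retrieval?"))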

1.2 Document Processing Pipeline

class DocumentProcessor:
    def __init__(self, config: RAGConfig):
        self.config = config
        self.text_splitter = self._init_splitter()
        self.metadata_extractor = MetadataExtractor()
        
    def _init_splitter(self):
        """Initialize the text splitter."""
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        
        return RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
            # Chinese and English sentence boundaries, then whitespace
            separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " ", ""],
            length_function=len
        )
    
    def process_documents(self, documents: List[Dict]) -> List[Dict]:
        """Process raw documents into context-enriched chunks."""
        processed_chunks = []
        
        for doc in documents:
            # Extract metadata
            metadata = self.metadata_extractor.extract(doc)
            
            # Clean the raw text
            cleaned_text = self.preprocess_text(doc["content"])
            
            # Type-aware chunking
            chunks = self.smart_chunking(cleaned_text, metadata)
            
            # Attach surrounding context to each chunk
            chunks_with_context = self.add_context(chunks)
            
            processed_chunks.extend(chunks_with_context)
        
        return processed_chunks
    
    def smart_chunking(self, text: str, metadata: Dict) -> List[str]:
        """Pick a chunking strategy based on document type."""
        doc_type = metadata.get("type", "general")
        
        if doc_type == "code":
            return self.chunk_code(text)
        elif doc_type == "table":
            return self.chunk_table(text)
        elif doc_type == "conversation":
            return self.chunk_conversation(text)
        else:
            return self.text_splitter.split_text(text)
    
    def chunk_code(self, code: str) -> List[str]:
        """Chunk Python source along function/class boundaries."""
        import ast
        chunks = []
        
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
                    chunk = ast.get_source_segment(code, node)
                    if chunk:
                        chunks.append(chunk)
        except SyntaxError:
            # Fall back to plain text splitting for unparsable code
            chunks = self.text_splitter.split_text(code)
        
        return chunks

2. Vector Storage and Indexing

2.1 Multi-Level Indexing Strategy

class HierarchicalIndex:
    def __init__(self):
        self.document_index = {}  # Document-level index
        self.chunk_index = {}     # Chunk-level index
        self.summary_index = {}   # Summary index
        
    def build_index(self, documents: List[Dict]):
        """Build the multi-level index."""
        for doc in documents:
            doc_id = doc["id"]
            
            # Document-level entry
            self.document_index[doc_id] = {
                "title": doc.get("title", ""),
                "summary": self.generate_summary(doc["content"]),
                "metadata": doc.get("metadata", {}),
                "chunk_ids": []
            }
            
            # Chunk-level entries
            chunks = self.create_chunks(doc["content"])
            for i, chunk in enumerate(chunks):
                chunk_id = f"{doc_id}_chunk_{i}"
                self.chunk_index[chunk_id] = {
                    "content": chunk,
                    "doc_id": doc_id,
                    "position": i,
                    "embedding": None  # Filled in later
                }
                self.document_index[doc_id]["chunk_ids"].append(chunk_id)
            
            # Summary embedding for coarse-grained routing
            summary_embedding = self.embed_text(
                self.document_index[doc_id]["summary"]
            )
            self.summary_index[doc_id] = summary_embedding
    
    def generate_summary(self, content: str) -> str:
        """Generate a document summary (load the pipeline once in production)."""
        from transformers import pipeline
        
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        summary = summarizer(content, max_length=150, min_length=50)[0]["summary_text"]
        return summary

2.2 Vector Database Operations

import time

class VectorStore:
    def __init__(self, embedding_model):
        self.embedding_model = embedding_model
        self.index = None
        self.metadata = {}
        
    async def add_documents(self, documents: List[Dict]):
        """Add documents to the vector store."""
        embeddings = []
        metadatas = []
        
        for doc in documents:
            # Embed the content
            embedding = await self.embedding_model.embed(doc["content"])
            embeddings.append(embedding)
            
            # Keep metadata alongside the vector
            metadata = {
                "id": doc["id"],
                "source": doc.get("source", ""),
                "timestamp": doc.get("timestamp", time.time()),
                "type": doc.get("type", "text")
            }
            metadatas.append(metadata)
        
        # Bulk insert
        await self.batch_insert(embeddings, metadatas)
    
    async def batch_insert(self, embeddings: List[np.ndarray], 
                          metadatas: List[Dict]):
        """Insert vectors in batches."""
        batch_size = 100
        
        for i in range(0, len(embeddings), batch_size):
            batch_embeddings = embeddings[i:i+batch_size]
            batch_metadatas = metadatas[i:i+batch_size]
            
            # Insert into the vector database
            ids = self.index.add(
                embeddings=batch_embeddings,
                metadatas=batch_metadatas
            )
            
            # Update the metadata mapping
            for id, metadata in zip(ids, batch_metadatas):
                self.metadata[id] = metadata

3. Hybrid Retrieval Strategy

graph LR
    subgraph "Hybrid Retrieval Flow"
        Q[Query] --> QE[Query Expansion]
        
        QE --> V[Vector Search]
        QE --> K[Keyword Search]
        QE --> S[Semantic Cache]
        
        V --> VR[Vector Results]
        K --> KR[Keyword Results]
        S --> SR[Cached Results]
        
        VR --> F[Result Fusion<br>RRF]
        KR --> F
        SR --> F
        F --> R[Reranking]
        R --> T[Top-K Results]
    end
    
    style Q fill:#e1f5fe,stroke:#01579b,stroke-width:2px
    style T fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
    style F fill:#fff3e0,stroke:#e65100,stroke-width:2px

3.1 Multi-Path Retrieval

class HybridRetriever:
    def __init__(self, vector_store):
        self.vector_store = vector_store
        self.bm25_index = BM25Index()
        self.semantic_cache = {}
        
    async def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Hybrid retrieval across vector, keyword, and cache paths."""
        # 1. Vector search
        vector_results = await self.vector_search(query, top_k * 2)
        
        # 2. Keyword search
        keyword_results = self.keyword_search(query, top_k * 2)
        
        # 3. Semantic cache lookup
        cache_results = self.search_cache(query)
        
        # 4. Fuse the result lists
        fused_results = self.fuse_results(
            vector_results, 
            keyword_results,
            cache_results
        )
        
        return fused_results[:top_k]
    
    async def vector_search(self, query: str, top_k: int) -> List[Dict]:
        """Vector similarity search."""
        query_embedding = await self.embed_query(query)
        
        results = self.vector_store.similarity_search(
            query_embedding,
            k=top_k,
            filter=self.build_filter(query)
        )
        
        return results
    
    def keyword_search(self, query: str, top_k: int) -> List[Dict]:
        """BM25 keyword search."""
        # Tokenize
        tokens = self.tokenize(query)
        
        # BM25 scoring
        scores = self.bm25_index.get_scores(tokens)
        
        # Take the top-k indices
        top_indices = np.argsort(scores)[-top_k:][::-1]
        
        results = []
        for idx in top_indices:
            results.append({
                "content": self.bm25_index.documents[idx],
                "score": scores[idx],
                "type": "keyword"
            })
        
        return results
    
    def fuse_results(self, *result_sets) -> List[Dict]:
        """Fuse result lists with Reciprocal Rank Fusion (RRF)."""
        from collections import defaultdict
        
        k = 60  # RRF constant
        scores = defaultdict(float)
        documents = {}
        
        for results in result_sets:
            for rank, result in enumerate(results):
                doc_id = result.get("id", str(result))
                scores[doc_id] += 1 / (k + rank + 1)
                documents[doc_id] = result
        
        # Sort by fused score
        sorted_ids = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
        
        fused_results = []
        for doc_id in sorted_ids:
            doc = documents[doc_id]
            doc["fusion_score"] = scores[doc_id]
            fused_results.append(doc)
        
        return fused_results
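
For reference, fuse_results implements Reciprocal Rank Fusion: a document's fused score is the sum of 1 / (k + rank + 1) over every result list it appears in, with k = 60 damping the advantage of top ranks. A quick sanity check of the method (called unbound, since it touches no instance state):

# RRF demo: "a" (ranks 0 and 2) narrowly beats "b" (ranks 1 and 1),
# and both beat "c", which appears in only one list.
vector_hits = [{"id": "a"}, {"id": "b"}]
keyword_hits = [{"id": "c"}, {"id": "b"}, {"id": "a"}]

fused = HybridRetriever.fuse_results(None, vector_hits, keyword_hits)
print([d["id"] for d in fused])  # ['a', 'b', 'c']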

3.2 Query Expansion

class QueryExpander:
    def __init__(self, llm):
        self.llm = llm
        self.synonym_dict = self.load_synonyms()
        
    async def expand_query(self, query: str) -> List[str]:
        """Expand the query along several axes."""
        expanded_queries = [query]
        
        # 1. Synonym expansion
        synonyms = self.get_synonyms(query)
        expanded_queries.extend(synonyms)
        
        # 2. LLM rephrasing
        rephrased = await self.rephrase_query(query)
        expanded_queries.extend(rephrased)
        
        # 3. Hypothetical document generation (HyDE)
        hypothetical = await self.generate_hypothetical_document(query)
        expanded_queries.append(hypothetical)
        
        return expanded_queries
    
    async def rephrase_query(self, query: str) -> List[str]:
        """Rephrase the query with an LLM."""
        prompt = f"""
        Rewrite the following query in 3 different ways while preserving its meaning:
        Original query: {query}
        
        Rewrites:
        1.
        2.
        3.
        """
        
        response = await self.llm.agenerate([prompt])
        rephrased = self.parse_rephrased(response)
        
        return rephrased
    
    async def generate_hypothetical_document(self, query: str) -> str:
        """Generate a hypothetical answer document (HyDE)."""
        prompt = f"""
        Imagine you are writing a document that answers the following question:
        {query}
        
        Write the first paragraph of that document:
        """
        
        response = await self.llm.agenerate([prompt])
        return response.generations[0][0].text
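
The value of HyDE is that the hypothetical passage is embedded in place of (or alongside) the raw question: an answer-shaped paragraph usually lies closer to the relevant chunks in embedding space than a short query does. A minimal sketch of wiring the expansions into retrieval, assuming the HybridRetriever above (the helper name is illustrative):

# Hypothetical helper: retrieve with every expanded query, then fuse via RRF
async def retrieve_with_expansion(expander: QueryExpander,
                                  retriever: HybridRetriever,
                                  query: str, top_k: int = 5):
    result_sets = []
    for q in await expander.expand_query(query):
        result_sets.append(await retriever.retrieve(q, top_k))
    return retriever.fuse_results(*result_sets)[:top_k]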

4. Reranking

4.1 Cross-Encoder Reranking

class CrossEncoderReranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        from sentence_transformers import CrossEncoder
        self.model = CrossEncoder(model_name)
        
    def rerank(self, query: str, documents: List[Dict], top_k: int = 3) -> List[Dict]:
        """Rerank documents against the query."""
        # Build (query, document) input pairs
        pairs = []
        for doc in documents:
            pairs.append([query, doc["content"]])
        
        # Score relevance for each pair
        scores = self.model.predict(pairs)
        
        # Attach scores
        for doc, score in zip(documents, scores):
            doc["rerank_score"] = float(score)
        
        # Sort by score
        reranked = sorted(documents, key=lambda x: x["rerank_score"], reverse=True)
        
        return reranked[:top_k]
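
Usage is straightforward once candidates are in hand. The cross-encoder reads the query and each document jointly, which is slower than bi-encoder retrieval but markedly more precise on the short list it sees:

# Rerank a handful of retrieved candidates (runs the real model,
# so sentence-transformers must be installed)
reranker = CrossEncoderReranker()
candidates = [
    {"content": "RRF fuses ranked lists using reciprocal-rank weights."},
    {"content": "BM25 is a classical lexical ranking function."},
]
top = reranker.rerank("How does reciprocal rank fusion work?", candidates, top_k=1)
print(top[0]["rerank_score"])  # cross-encoder relevance of the best match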

4.2 Diversity Reranking

class DiversityReranker:
    def __init__(self, lambda_param: float = 0.5):
        self.lambda_param = lambda_param
        
    def rerank_with_diversity(self, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        """MMR (Maximal Marginal Relevance) reranking."""
        import numpy as np
        from sklearn.metrics.pairwise import cosine_similarity
        
        # Document embeddings
        embeddings = np.array([doc["embedding"] for doc in documents])
        
        # Initialize
        selected = []
        selected_embeddings = []
        remaining = list(range(len(documents)))
        
        # Seed with the first document (assumes input is sorted by relevance)
        first_idx = 0
        selected.append(first_idx)
        selected_embeddings.append(embeddings[first_idx])
        remaining.remove(first_idx)
        
        # Iteratively select the rest
        while len(selected) < min(top_k, len(documents)):
            mmr_scores = []
            
            for idx in remaining:
                # Relevance score
                relevance = documents[idx].get("score", 0)
                
                # Maximum similarity to anything already selected
                similarities = cosine_similarity(
                    [embeddings[idx]], 
                    selected_embeddings
                )[0]
                max_similarity = max(similarities) if similarities.size > 0 else 0
                
                # MMR score
                mmr = self.lambda_param * relevance - (1 - self.lambda_param) * max_similarity
                mmr_scores.append((idx, mmr))
            
            # Pick the candidate with the highest MMR
            best_idx = max(mmr_scores, key=lambda x: x[1])[0]
            selected.append(best_idx)
            selected_embeddings.append(embeddings[best_idx])
            remaining.remove(best_idx)
        
        # Return documents in selection order
        return [documents[idx] for idx in selected]
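
To see the λ trade-off concretely: with λ = 0.5, a candidate with relevance 0.8 but similarity 0.9 to an already-selected document scores 0.5 × 0.8 − 0.5 × 0.9 = −0.05, while a candidate with relevance 0.6 and similarity 0.3 scores 0.5 × 0.6 − 0.5 × 0.3 = 0.15 and wins. Pushing λ toward 1 recovers pure relevance ranking; pushing it toward 0 maximizes diversity.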

5. Generation Strategy Optimization

5.1 Context Optimization

class ContextOptimizer:
    def __init__(self, max_tokens: int = 4000):
        self.max_tokens = max_tokens
        
    def optimize_context(self, query: str, documents: List[Dict]) -> str:
        """Optimize the context window."""
        # 1. Compress documents
        compressed_docs = self.compress_documents(documents)
        
        # 2. Order by priority
        prioritized_docs = self.prioritize_content(query, compressed_docs)
        
        # 3. Format the context
        context = self.format_context(prioritized_docs)
        
        # 4. Enforce the token budget
        context = self.truncate_to_limit(context)
        
        return context
    
    def compress_documents(self, documents: List[Dict]) -> List[Dict]:
        """Compress document contents."""
        compressed = []
        
        for doc in documents:
            # Keep only the key sentences
            key_sentences = self.extract_key_sentences(doc["content"])
            
            compressed.append({
                **doc,
                "compressed_content": " ".join(key_sentences)
            })
        
        return compressed
    
    def extract_key_sentences(self, text: str, num_sentences: int = 3) -> List[str]:
        """Extract key sentences by embedding centrality."""
        from sentence_transformers import SentenceTransformer
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np
        
        # Naive split on the Chinese full stop; swap in a real sentence
        # splitter for mixed-language corpora
        sentences = text.split("。")
        if len(sentences) <= num_sentences:
            return sentences
        
        # Sentence embeddings (load the model once in production)
        model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        embeddings = model.encode(sentences)
        
        # Centrality: total similarity to all other sentences
        similarity_matrix = cosine_similarity(embeddings)
        scores = similarity_matrix.sum(axis=1)
        
        # Take the highest-scoring sentences
        top_indices = np.argsort(scores)[-num_sentences:]
        
        # Restore original document order
        top_indices = sorted(top_indices)
        
        return [sentences[i] for i in top_indices]

5.2 Augmented Generation

from datetime import datetime

class AugmentedGenerator:
    def __init__(self, llm):
        self.llm = llm
        self.prompt_templates = self.load_templates()
        
    async def generate(self, query: str, context: str, 
                      generation_config: Dict = None) -> str:
        """Context-augmented generation."""
        # 1. Select a prompt template
        template = self.select_template(query)
        
        # 2. Build the prompt
        prompt = self.build_prompt(template, query, context)
        
        # 3. Generate the answer
        response = await self.llm.agenerate(
            [prompt],
            **generation_config or {}
        )
        
        # 4. Post-process
        answer = self.postprocess(response.generations[0][0].text)
        
        # 5. Validate the answer
        if not self.validate_answer(answer, context):
            answer = await self.regenerate_with_constraints(query, context)
        
        return answer
    
    def build_prompt(self, template: str, query: str, context: str) -> str:
        """Fill in the prompt template."""
        return template.format(
            context=context,
            query=query,
            current_date=datetime.now().strftime("%Y-%m-%d")
        )
    
    def validate_answer(self, answer: str, context: str) -> bool:
        """Validate the answer against the retrieved context."""
        # Crude grounding heuristic: long answers should explicitly
        # acknowledge the provided context
        if "Based on the provided information" not in answer and len(answer) > 100:
            return False
        
        # Check every stated fact against the context
        facts_in_answer = self.extract_facts(answer)
        facts_in_context = self.extract_facts(context)
        
        for fact in facts_in_answer:
            if not self.fact_supported(fact, facts_in_context):
                return False
        
        return True

6. Evaluation and Optimization

6.1 Retrieval Quality Evaluation

class RetrievalEvaluator:
    def __init__(self):
        self.metrics = {}
        
    def evaluate(self, queries: List[str], ground_truth: List[List[str]], 
                retrieved: List[List[str]]) -> Dict:
        """Evaluate retrieval quality."""
        metrics = {
            "mrr": self.calculate_mrr(ground_truth, retrieved),
            "map": self.calculate_map(ground_truth, retrieved),
            "ndcg": self.calculate_ndcg(ground_truth, retrieved),
            "recall@k": {},
            "precision@k": {}
        }
        
        for k in [1, 3, 5, 10]:
            metrics["recall@k"][k] = self.calculate_recall_at_k(
                ground_truth, retrieved, k
            )
            metrics["precision@k"][k] = self.calculate_precision_at_k(
                ground_truth, retrieved, k
            )
        
        return metrics
    
    def calculate_mrr(self, ground_truth: List[List[str]], 
                     retrieved: List[List[str]]) -> float:
        """Compute MRR (Mean Reciprocal Rank)."""
        mrr = 0
        for gt, ret in zip(ground_truth, retrieved):
            for i, doc in enumerate(ret):
                if doc in gt:
                    mrr += 1 / (i + 1)
                    break
        
        return mrr / len(ground_truth)
    
    def calculate_ndcg(self, ground_truth: List[List[str]], 
                      retrieved: List[List[str]], k: int = 10) -> float:
        """Compute NDCG."""
        import numpy as np
        
        def dcg(relevances, k):
            relevances = np.array(relevances)[:k]
            if relevances.size:
                return np.sum(relevances / np.log2(np.arange(2, relevances.size + 2)))
            return 0
        
        ndcg_scores = []
        for gt, ret in zip(ground_truth, retrieved):
            relevances = [1 if doc in gt else 0 for doc in ret[:k]]
            ideal_relevances = [1] * min(len(gt), k) + [0] * (k - min(len(gt), k))
            
            dcg_score = dcg(relevances, k)
            idcg_score = dcg(ideal_relevances, k)
            
            if idcg_score > 0:
                ndcg_scores.append(dcg_score / idcg_score)
            else:
                ndcg_scores.append(0)
        
        return np.mean(ndcg_scores)
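
A quick worked example: if the relevant document shows up at rank 1 for the first query and rank 3 for the second, MRR = (1/1 + 1/3) / 2 ≈ 0.667. The same numbers run through the evaluator above:

evaluator = RetrievalEvaluator()
ground_truth = [["doc_a"], ["doc_b"]]
retrieved = [["doc_a", "doc_c"], ["doc_c", "doc_d", "doc_b"]]
print(evaluator.calculate_mrr(ground_truth, retrieved))  # (1/1 + 1/3) / 2 ≈ 0.667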

6.2 Generation Quality Evaluation

class GenerationEvaluator:
    def __init__(self):
        self.faithfulness_model = self.load_faithfulness_model()
        self.relevance_model = self.load_relevance_model()
        
    async def evaluate_generation(self, query: str, context: str, 
                                 answer: str) -> Dict:
        """Evaluate generation quality."""
        metrics = {
            "faithfulness": await self.evaluate_faithfulness(answer, context),
            "relevance": await self.evaluate_relevance(answer, query),
            "completeness": self.evaluate_completeness(answer, query),
            "fluency": self.evaluate_fluency(answer)
        }
        
        return metrics
    
    async def evaluate_faithfulness(self, answer: str, context: str) -> float:
        """Score how faithful the answer is to the context."""
        prompt = f"""
        Assess whether the answer is faithful to the given context.
        
        Context: {context}
        Answer: {answer}
        
        Score (0-1):
        """
        
        response = await self.faithfulness_model.agenerate([prompt])
        score = self.extract_score(response)
        
        return score
    
    def evaluate_fluency(self, answer: str) -> float:
        """Score fluency via GPT-2 perplexity."""
        from transformers import GPT2LMHeadModel, GPT2Tokenizer
        import torch
        
        # Load once in production rather than per call
        model = GPT2LMHeadModel.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        
        inputs = tokenizer(answer, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            perplexity = torch.exp(loss)
        
        # Map perplexity onto a 0-1 score
        fluency_score = 1 / (1 + perplexity.item() / 100)
        
        return fluency_score

7. Production Deployment

7.1 API Service

from typing import Dict, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import asyncio

app = FastAPI()

class RAGRequest(BaseModel):
    query: str
    top_k: int = 5
    temperature: float = 0.7
    max_tokens: int = 500

class RAGResponse(BaseModel):
    answer: str
    sources: List[Dict]
    confidence: float

# Initialize the RAG system
rag_agent = RAGAgent(RAGConfig())

@app.post("/rag/query", response_model=RAGResponse)
async def query_rag(request: RAGRequest):
    try:
        # Retrieve candidate documents
        documents = await rag_agent.retriever.retrieve(
            request.query, 
            request.top_k
        )
        
        # Rerank
        reranked = rag_agent.reranker.rerank(
            request.query,
            documents,
            top_k=3
        )
        
        # Generate the answer from the reranked context
        context = "\n\n".join(doc["content"] for doc in reranked)
        answer = await rag_agent.generator.generate(
            request.query,
            context,
            {
                "temperature": request.temperature,
                "max_tokens": request.max_tokens
            }
        )
        
        return RAGResponse(
            answer=answer,
            sources=[{"content": doc["content"][:200], "score": doc.get("score", 0)} 
                    for doc in reranked],
            # calculate_confidence: an aggregate over rerank scores, defined elsewhere
            confidence=calculate_confidence(reranked)
        )
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/rag/index")
async def index_documents(documents: List[Dict]):
    """Index new documents."""
    try:
        processed = rag_agent.document_processor.process_documents(documents)
        await rag_agent.vector_store.add_documents(processed)
        return {"status": "success", "indexed": len(processed)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
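
For a quick smoke test, the endpoint can be exercised with any HTTP client; the URL and payload below are illustrative:

import requests

# Assumes the service is running locally, e.g. `uvicorn app:app --port 8000`
resp = requests.post(
    "http://localhost:8000/rag/query",
    json={"query": "How does hybrid retrieval work?", "top_k": 5},
)
print(resp.json()["answer"])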

7.2 Performance Optimization

import asyncio
from cachetools import LRUCache  # one possible LRU cache implementation

class RAGOptimizer:
    def __init__(self, rag_agent: RAGAgent):
        self.rag_agent = rag_agent
        self.cache = LRUCache(maxsize=1000)
        self.query_embeddings_cache = {}
        
    async def optimized_retrieve(self, query: str) -> List[Dict]:
        """Retrieval with caching and parallel search."""
        # 1. Check the result cache
        cache_key = self.get_cache_key(query)
        if cache_key in self.cache:
            return self.cache[cache_key]
        
        # 2. Reuse cached query embeddings
        if query in self.query_embeddings_cache:
            query_embedding = self.query_embeddings_cache[query]
        else:
            # Note the parentheses: await the coroutine before indexing
            query_embedding = (await self.batch_embed([query]))[0]
            self.query_embeddings_cache[query] = query_embedding
        
        # 3. Run vector and keyword search in parallel
        tasks = [
            self.rag_agent.vector_store.search(query_embedding),
            self.rag_agent.bm25_index.search(query)
        ]
        
        vector_results, keyword_results = await asyncio.gather(*tasks)
        
        # 4. Fuse and cache the results
        results = self.rag_agent.retriever.fuse_results(
            vector_results,
            keyword_results
        )
        
        self.cache[cache_key] = results
        
        return results

8. Advanced Features

8.1 Incremental Learning

import time

class IncrementalLearningRAG:
    def __init__(self):
        self.feedback_buffer = []
        self.update_frequency = 100
        
    async def learn_from_feedback(self, query: str, answer: str, 
                                 feedback: Dict):
        """Learn from user feedback."""
        self.feedback_buffer.append({
            "query": query,
            "answer": answer,
            "feedback": feedback,
            "timestamp": time.time()
        })
        
        if len(self.feedback_buffer) >= self.update_frequency:
            await self.update_model()
    
    async def update_model(self):
        """Update the system from buffered feedback."""
        # 1. Split feedback by rating
        positive_examples = []
        negative_examples = []
        
        for item in self.feedback_buffer:
            if item["feedback"]["rating"] >= 4:
                positive_examples.append(item)
            else:
                negative_examples.append(item)
        
        # 2. Build training data
        training_data = self.prepare_training_data(
            positive_examples,
            negative_examples
        )
        
        # 3. Fine-tune the embedding model
        await self.finetune_embeddings(training_data)
        
        # 4. Update the retrieval strategy
        self.update_retrieval_strategy(training_data)
        
        # Clear the buffer
        self.feedback_buffer = []

8.2 Multimodal RAG

class MultiModalRAG:
    def __init__(self):
        self.text_encoder = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.cross_modal_encoder = self.load_cross_modal_encoder()
        
    async def process_multimodal_query(self, query: Dict) -> str:
        """Handle a multimodal query."""
        text_query = query.get("text", "")
        image_query = query.get("image")
        
        # 1. Encode the query
        if text_query and image_query:
            query_embedding = await self.encode_multimodal(
                text_query,
                image_query
            )
        elif text_query:
            query_embedding = await self.text_encoder.encode(text_query)
        elif image_query:
            query_embedding = await self.image_encoder.encode(image_query)
        else:
            raise ValueError("Query must contain text, an image, or both")
        
        # 2. Cross-modal retrieval
        results = await self.retrieve_multimodal(query_embedding)
        
        # 3. Generate the answer
        answer = await self.generate_from_multimodal(results, query)
        
        return answer
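
One common way to realize the shared embedding space is a CLIP-style model. The sketch below uses the clip-ViT-B-32 checkpoint via sentence-transformers; the model choice and file name are assumptions for illustration, not part of the design above:

from PIL import Image
from sentence_transformers import SentenceTransformer

# CLIP maps text and images into one vector space, so a single
# index can serve both modalities
clip = SentenceTransformer("clip-ViT-B-32")

text_vec = clip.encode("a diagram of a RAG pipeline")
image_vec = clip.encode(Image.open("architecture.png"))  # hypothetical file
# Cosine similarity between text_vec and image_vec drives cross-modal retrieval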

9. Best Practices

  1. Document preprocessing: clean, deduplicate, and normalize formats
  2. Smart chunking: pick the chunking strategy by document type
  3. Hybrid retrieval: combine vector and keyword search
  4. Reranking: use a cross-encoder to sharpen precision
  5. Context optimization: compress and prioritize retrieved content
  6. Caching: use multi-level caches to cut latency and cost
  7. Continuous improvement: iterate on the system from user feedback

Conclusion

By coupling retrieval with generation, a RAG Agent markedly improves the accuracy and practical usefulness of an LLM. The leverage points are retrieval quality, context management, and the generation strategy.
