知识图谱检索增强的GraphRAG(基于Neo4j代码实现)

为了存储知识图谱数据,首先需要搭建一个 Neo4j 实例。最简单的方法是在 Neo4j Aura 上
首页 新闻资讯 行业资讯 知识图谱检索增强的GraphRAG(基于Neo4j代码实现)

图检索增强生成(Graph RAG)正逐渐流行起来,成为传统向量搜索方法的有力补充。这种方法利用图数据库的结构化特性,将数据以节点和关系的形式组织起来,从而增强检索信息的深度和上下文关联性。图在表示和存储多样化且相互关联的信息方面具有天然优势,能够轻松捕捉不同数据类型间的复杂关系和属性。而向量数据库在处理这类结构化信息时则显得力不从心,它们更擅长通过高维向量处理非结构化数据。在 RAG 应用中,结合结构化的图数据和非结构化的文本向量搜索,可以让我们同时享受两者的优势,这也是本文将要探讨的内容。


构建知识图谱通常是利用图数据表示的强大功能中最困难的一步。它需要收集和整理数据,这需要对领域知识和图建模有深刻的理解。为了简化这一过程,可以参考已有的项目或者利用LLM来创建知识图谱,进而可以把重点放在检索召回,以及LLM的生成阶段。下面来进行相关代码的实践。

1.知识图谱构建

为了存储知识图谱数据,首先需要搭建一个 Neo4j 实例。最简单的方法是在 Neo4j Aura 上启动一个免费实例,它提供了 Neo4j 数据库的云版本。当然,也可以通过docker本地启动一个,然后将图谱数据导入到Neo4j 数据库中。

步骤I:Neo4j环境搭建

下面是本地启动docker的运行示例:

docker run-d \--restart always \--publish=7474:7474 --publish=7687:7687 \--env NEO4J_AUTH=neo4j/000000 \--volume=/yourdockerVolume/neo4j:/data \neo4j:5.19.0

步骤II:图谱数据导入

演示中,我们可以使用伊丽莎白一世的维基百科页面。利用 LangChain 加载器从维基百科获取并分割文档,后存入Neo4j数据库。为了试验中文上的效果,我们导入这个Github上的这个项目(QASystemOnMedicalKG)中的医学知识图谱,包含近35000个节点,30万组三元组,大致得到如下结果:

图片图片

或者利用LangChainLangChain 加载器从维基百科获取并分割文档,大致如下面步骤所示:

# 读取维基百科文章raw_documents=WikipediaLoader(query="Elizabeth I").load()# 定义分块策略text_splitter=TokenTextSplitter(chunk_size=512,chunk_overlap=24)documents=text_splitter.split_documents(raw_documents[:3])llm=ChatOpenAI(temperature=0,model_name="gpt-4-0125-preview")llm_transformer=LLMGraphTransformer(llm=llm)# 提取图数据graph_documents=llm_transformer.convert_to_graph_documents(documents)# 存储到 neo4jgraph.add_graph_documents(graph_documents,baseEntityLabel=True,include_source=True)

2.知识图谱检索

在对知识图谱检索之前,需要对实体和相关属性进行向量嵌入并存储到Neo4j数据库中:

  • 实体信息向量嵌入:将实体名称和实体的描述信息拼接后,利用向量表征模型进行向量嵌入(如下述示例代码中的add_embeddings方法所示)。

  • 图谱结构化检索:图谱的结构化检索分为四个步骤:步骤一,从图谱中检索与查询相关的实体;步骤二,从全局索引中检索得到实体的标签;步骤三,根据实体标签在相应的节点中查询邻居节点路径;步骤四,对关系进行筛选,保持多样性(整个检索过程如下述示例代码中的structured_retriever方法所示)。

class GraphRag(object):
    def __init__(self):"""Any embedding function implementing `langchain.embeddings.base.Embeddings` interface."""self._database='neo4j'self.label='Med'self._driver=neo4j.GraphDatabase.driver(uri=os.environ["NEO4J_URI"],auth=(os.environ["NEO4J_USERNAME"],os.environ["NEO4J_PASSWORD"]))self.embeddings_zh=HuggingFaceEmbeddings(model_name=os.environ["EMBEDDING_MODEL"])self.vectstore=Neo4jVector(embedding=self.embeddings_zh,username=os.environ["NEO4J_USERNAME"],password=os.environ["NEO4J_PASSWORD"],url=os.environ["NEO4J_URI"],node_label=self.label,index_name="vector")def query(self,query: str,params: dict={})->List[Dict[str,Any]]:"""Query Neo4j database."""fromneo4j.exceptionsimportCypherSyntaxErrorwithself._driver.session(database=self._database)assession:
            try:data=session.run(query,params)return[r.data()forrindata]exceptCypherSyntaxErrorase:
                raise ValueError(f"Generated Cypher Statement is not valid\n{e}")def add_embeddings(self):"""Add embeddings to Neo4j database."""# 查询图中所有节点,并且根据节点的描述和名字生成embedding,添加到该节点上query="""MATCH (n) WHERE not (n:{}) RETURN ID(n) AS id, labels(n) as labels, n""".format(self.label)print("qurey node...")data=self.query(query)ids,texts,embeddings,metas=[],[],[],[]forrowintqdm(data,desc="parsing node"):
            ids.append(row['id'])text=row['n'].get('name','')+row['n'].get('desc','')texts.append(text)metas.append({"label":row['labels'],"context":text})self.embeddings_zh.multi_process=Falseprint("node embeddings...")embeddings=self.embeddings_zh.embed_documents(texts)print("adding node embeddings...")ids_ret=self.vectstore.add_embeddings(ids=ids,texts=texts,embeddings=embeddings,metadatas=metas)returnids_ret# Fulltext index querydef structured_retriever(self,query,limit=3,simlarity=0.9)->str:"""
        Collects the neighborhood of entities mentioned in the question
        """# step1 从图谱中检索与查询相关的实体。docs_with_score=self.vectstore.similarity_search_with_score(query,k=topk)entities=[item[0].page_contentforitemindataifitem[1]>simlarity]# scoreself.vectstore.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:Med) ON EACH [e.context]")result=""forentityinentities:
            qry=entity# step2 从全局索引中查出entity label,query1=f"""CALL db.index.fulltext.queryNodes('entity', '{qry}') YIELD node, score 
                            return node.label as label,node.context as context, node.id as id, score LIMIT {limit}"""data1=self.vectstore.query(query1)# step3 根据label在相应的节点中查询邻居节点路径foritemindata1:
                node_type=item['label']node_type=item['label']iftype(node_type)==strelsenode_type[0]node_id=item['id']query2=f"""match (node:{node_type})-[r]-(neighbor) where ID(node) = {node_id} RETURN type(r) as rel, node.name+' - '+type(r)+' - '+neighbor.name as output limit 50"""data2=self.vectstore.query(query2)# step4 为了保持多样性,对关系进行筛选rel_dict=defaultdict(list)iflen(data2)>3*limit:foritem1indata2:
                        rel_dict[item1['rel']].append(item1['output'])ifrel_dict:
                    rel_dict={k:random.sample(v,3)iflen(v)>3elsevfork,vinrel_dict.items()}
                    result+="\n".join(['\n'.join(el)forelinrel_dict.values()])+'\n'else:
                    result+="\n".join([el['output']forelindata2])+'\n'returnresult

3.结合LLM生成

最后利用大语言模型(LLM)根据从知识图谱中检索出来的结构化信息,生成最终的回复。下面的代码中我们以通义千问开源的大语言模型为例:

步骤I:加载LLM模型

fromlangchainimportHuggingFacePipelinefromtransformersimportpipeline,AutoTokenizer,AutoModelForCausalLM


def custom_model(model_name,branch_name=None,cache_dir=None,temperature=0,top_p=1,max_new_tokens=512,stream=False):
    tokenizer=AutoTokenizer.from_pretrained(model_name,revision=branch_name,cache_dir=cache_dir)model=AutoModelForCausalLM.from_pretrained(model_name,device_map='auto',torch_dtype=torch.float16,revision=branch_name,cache_dir=cache_dir)pipe=pipeline("text-generation",model=model,tokenizer=tokenizer,torch_dtype=torch.bfloat16,device_map='auto',max_new_tokens=max_new_tokens,do_sample=True)llm=HuggingFacePipeline(pipeline=pipe,model_kwargs={"temperature":temperature,"top_p":top_p,"tokenizer":tokenizer,"model":model})returnllm
    
tongyi_model="Qwen1.5-7B-Chat"llm_model=custom_model(model_name=tongyi_model)tokenizer=llm_model.model_kwargs['tokenizer']model=llm_model.model_kwargs['model']


步骤II:输入检索数据生成回复

final_data=self.get_retrieval_data(query)prompt=("请结合以下信息,简洁和专业的来回答用户的问题,若信息与问题关联紧密,请尽量参考已知信息。\n""已知相关信息:\n{context} 请回答以下问题:{question}".format(cnotallow=final_data,questinotallow=query))messages=[{"role":"system","content":"你是**开发的智能助手。"},{"role":"user","content": prompt}]text=tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)model_inputs=tokenizer([text],return_tensors="pt").to(self.device)generated_ids=model.generate(model_inputs.input_ids,max_new_tokens=512)generated_ids=[output_ids[len(input_ids):]forinput_ids,output_idsinzip(model_inputs.input_ids,generated_ids)]response=tokenizer.batch_decode(generated_ids,skip_special_tokens=True)[0]print(response)

4 结语

对一个查询问题分别进行了测试, 与没有RAG仅利用LLM生成回复的的情况进行对比,在有GraphRAG 的情况下,LLM模型回答的信息量更大、准确会更高。

27    2024-06-03 10:53:18    LLM RAG GraphRAG