# -*- coding: utf-8 -*-
|
# entity_recognition.py
|
# @author: lyg
|
# @date: 2025-04-24
|
# @version: 0.0.1
|
# @description: 实体抽取,将文本中的实体进行识别和提取。
|
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
|
from langchain_core.output_parsers import JsonOutputParser
|
import json
|
|
from knowledgebase import utils
|
from knowledgebase.db.doc_db_helper import doc_dbh
|
from knowledgebase.log import Log
|
from knowledgebase.llm import llm
|
|
class EntityRecognition:
|
"""
|
实体识别抽取。
|
|
使用langchain构建实体抽取流程。
|
"""
|
use_cache = False
|
cache_file = "entity_recognition.cache"
|
|
def __init__(self, doc_type: str):
|
# 实体词列表
|
entities = doc_dbh.get_entities_by_doc_type(doc_type)
|
entity_list = ','.join([entity.name for entity in entities]) + "。"
|
entity_rules = ";\n".join([f"- {entity.name}:{entity.prompts}" for entity in entities]) + "。"
|
tpl = """
|
# 指令
|
请根据实体词判断规则从给定的文本中判断是否有下列实体词相关内容,如果有则输出相关的实体词,没有则不输出,实体词列表定义如下:
|
""" + entity_list + """
|
## 实体词判断规则:
|
""" + entity_rules + """
|
# 约束
|
- 输出格式为JSON格式;
|
- 提取的实体词必须是:""" + entity_list + """;
|
- 如果没有符合上述规则的实体词则不要输出任何实体词;
|
- 输出数据结构为字符串数组。
|
# 示例
|
```json
|
[\"""" + entities[0].name + """\"]
|
```
|
|
# 文本如下:
|
{text}
|
"""
|
Log.info(tpl)
|
msg = HumanMessagePromptTemplate.from_template(template=tpl)
|
prompt = ChatPromptTemplate.from_messages([msg])
|
parser = JsonOutputParser(pydantic_object=list[str])
|
self.chain = prompt | llm | parser
|
self.cache = {}
|
self.load_cache()
|
|
def load_cache(self):
|
"""
|
加载缓存。
|
"""
|
if utils.file_exists(self.cache_file):
|
text = utils.read_from_file(self.cache_file)
|
self.cache = json.loads(text)
|
|
def save_cache(self):
|
"""
|
保存缓存。
|
"""
|
text = json.dumps(self.cache)
|
utils.save_text_to_file(text, self.cache_file)
|
|
async def run(self, in_text: str) -> list[str]:
|
"""
|
运行实体识别抽取。
|
:param in_text: str - 输入文本
|
"""
|
# 缓存命中
|
text_md5 = utils.generate_text_md5(in_text)
|
if self.use_cache and text_md5 in self.cache:
|
return self.cache[text_md5]
|
result = await self.chain.ainvoke({"text": in_text})
|
self.cache[text_md5] = result
|
self.save_cache()
|
return result
|