From 22f370322412074174cde20ecfd14ec03657ab63 Mon Sep 17 00:00:00 2001 From: lyg <1543117173@qq.com> Date: 星期一, 07 七月 2025 16:20:25 +0800 Subject: [PATCH] 生成数据库 --- knowledgebase/db/doc_db_helper.py | 84 ++++++++++++++++++++++++++++++------------ 1 files changed, 60 insertions(+), 24 deletions(-) diff --git a/knowledgebase/db/doc_db_helper.py b/knowledgebase/db/doc_db_helper.py index 6b8d6cb..a0b82b1 100644 --- a/knowledgebase/db/doc_db_helper.py +++ b/knowledgebase/db/doc_db_helper.py @@ -6,6 +6,7 @@ # @description: 文档数据库助手，mysql数据库 import json +from threading import RLock from knowledgebase.db.doc_db_models import init_doc_db, TDoc, TEntity, TParagraph, TParagraphLink, TParagraphRefLink, \ TParagraphEntityLink @@ -17,9 +18,13 @@ """ 文档数据库助手 """ + lock = RLock() def __init__(self): - self.session = init_doc_db() + self.session = None + + def set_project_path(self, project_path): + self.session = init_doc_db(project_path) def add_doc(self, doc_info: DocInfo) -> int: """ @@ -32,6 +37,7 @@ ) self.session.add(_doc) self.session.commit() + doc_info.id = _doc.id return _doc.id def add_paragraph(self, doc_id: int, parent_id: int, paragraph_info: ParagraphInfo) -> TParagraph: @@ -62,6 +68,7 @@ if paragraph_info.children: for child in paragraph_info.children: self.add_paragraph(doc_id, _paragraph.id, child) + paragraph_info.id = _paragraph.id return _paragraph def add_paragraph_link(self, paragraph_link): @@ -73,6 +80,18 @@ self.session.commit() return paragraph_link.id + def add_paragraph_ref_link(self, paren_id: int, child_id: int) -> int: + """ + 添加段落引用关系 + :param paren_id: 引用段落 + :param child_id: 被引用段落 + :return: + """ + link = TParagraphRefLink(parent_id=paren_id, child_id=child_id, is_del=0) + self.session.add(link) + self.session.commit() + return link.id + def add_paragraph_entity_link(self, paragraph_entity_link): """ 添加段落实体关系 @@ -83,10 +102,11 @@ return paragraph_entity_link.id def get_entity(self, entity): - ret = 
self.session.query(TEntity).where( - TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first() - if ret: - return ret + with self.lock: + ret = self.session.query(TEntity).where( + TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first() + if ret: + return ret def add_entity(self, entity): """ @@ -97,20 +117,14 @@ self.session.commit() return entity.id - def add_paragraph_ref_link(self, paragraph_ref_link): - """ - 添加段落引用关系 - :param paragraph_ref_link: 段落引用关系 - """ - self.session.add(paragraph_ref_link) - self.session.commit() - return paragraph_ref_link - def get_all_entities(self) -> list[TEntity]: - return self.session.query(TEntity).all() + with self.lock: + return self.session.query(TEntity).all() def get_docs(self) -> list[TDoc]: - return self.session.query(TDoc).all() + with self.lock: + return self.session.query(TDoc).all() + def get_texts_with_entities(self, entity_names: list[str]): """ @@ -118,13 +132,29 @@ :param entity_names: list[str] - 实体词 :return: list[str] - 文本列表 """ - if not entity_names: - return "" - _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all() - _entitie_ids = [entity.id for entity in _entities] - links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all() - _paragraphs = [link.paragraph for link in links] - return [self.get_paragraph_full_text(p) for p in _paragraphs] + with self.lock: + if not entity_names: + return "" + _entities = self.session.query(TEntity).where(TEntity.name.in_(entity_names)).all() + _entitie_ids = [entity.id for entity in _entities] + links = self.session.query(TParagraphEntityLink).where( + TParagraphEntityLink.entity_id.in_(_entitie_ids)).all() + _paragraphs:[TParagraph] = [link.paragraph for link in links] + ref_paragraphs = [] + for p in _paragraphs: + ref_paragraphs.extend([x.child for x in p.ref_links]) + 
_paragraphs.extend(ref_paragraphs) + id_map = {} + result = [] + for p in _paragraphs: + if p.id in id_map: + continue + else: + id_map[p.id] = p + result.append(p) + return [p.text for p in result] + + def get_text_with_entities(self, entity_names: list[str]) -> str: """ 根据实体词获取文本内容 @@ -143,7 +173,13 @@ return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children]) def get_entities_by_doc_type(self, doc_type): - _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all() + with self.lock: + _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all() + return _entities + + def get_entities_by_type(self, ty: str)->list[TEntity]: + with self.lock: + _entities = self.session.query(TEntity).where(TEntity.type == ty).all() return _entities def commit(self): -- Gitblit v1.9.1