lyg
2025-05-22 e60d75228fb161e464ca59fa2526bf0765f4d902
knowledgebase/db/doc_db_helper.py
@@ -82,6 +82,12 @@
        self.session.commit()
        return paragraph_entity_link.id
    def get_entity(self, entity):
        ret = self.session.query(TEntity).where(
            TEntity.name == entity.name and TEntity.type == entity.type and TEntity.doc_type == entity.doc_type).first()
        if ret:
            return ret
    def add_entity(self, entity):
        """
        添加实体
@@ -106,11 +112,11 @@
    def get_docs(self) -> list[TDoc]:
        return self.session.query(TDoc).all()
    def get_text_with_entities(self, entity_names: list[str]) -> str:
    def get_texts_with_entities(self, entity_names: list[str]):
        """
        根据实体词获取文本内容
        根据实体词获取文本内容列表
        :param entity_names: list[str] - 实体词
        :return: str - 文本
        :return: list[str] - 文本列表
        """
        if not entity_names:
            return ""
@@ -118,26 +124,30 @@
        _entitie_ids = [entity.id for entity in _entities]
        links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all()
        _paragraphs = [link.paragraph for link in links]
        return [self.get_paragraph_full_text(p) for p in _paragraphs]
    def get_text_with_entities(self, entity_names: list[str]) -> str:
        """
        根据实体词获取文本内容
        :param entity_names: list[str] - 实体词
        :return: str - 文本
        """
        texts = self.get_texts_with_entities(entity_names)
        return '\n'.join(texts)
        return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs])
    def get_entities_by_names(self, names: list[str]):
        _entities = self.session.query(TEntity).where(TEntity.name.in_(names)).all()
        return _entities
    def get_paragraph_full_text(self, p: TParagraph):
        result = p.text if p.title_level == 0 else p.title_num + ' ' + p.text
        return result + '\n' + '\n'.join([self.get_paragraph_full_text(p) for p in p.children])
    def get_entities_by_doc_type(self, doc_type):
        _entities = self.session.query(TEntity).where(TEntity.doc_type == doc_type).all()
        return _entities
    def commit(self):
        self.session.commit()
# Module-level DocDbHelper instance, constructed when this module is imported.
doc_dbh = DocDbHelper()
# if __name__ == '__main__':
#     text = doc_dbh.get_text_with_entities(['遥控包格式'])
#     print(text)
#     doc_db = DocDbHelper()
#     # doc_db.insert_entities()
#     doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test'))
#     p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1))
#     p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2))
#     p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3))
#     doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id))