| | |
| | | self.session.commit() |
| | | return paragraph_entity_link.id |
| | | |
def get_entity(self, entity):
    """
    Look up the stored entity matching the given one.

    A row matches when its name, type and doc_type all equal the
    corresponding attributes of ``entity``.

    :param entity: object exposing ``name``, ``type`` and ``doc_type``
    :return: the first matching TEntity row, or None when nothing matches
    """
    # The criteria are passed as separate arguments so that all three end up
    # in the query. Python's ``and`` cannot be overloaded by column
    # expressions: ``a and b and c`` evaluates to just one of its operands,
    # so chaining comparisons with it drops conditions instead of AND-ing
    # them in SQL.
    return self.session.query(TEntity).where(
        TEntity.name == entity.name,
        TEntity.type == entity.type,
        TEntity.doc_type == entity.doc_type,
    ).first()
| | | |
| | | def add_entity(self, entity): |
| | | """ |
| | | 添加实体 |
| | |
def get_docs(self) -> list[TDoc]:
    """
    Fetch every document record.

    :return: list[TDoc] - all rows produced by the TDoc query
    """
    doc_query = self.session.query(TDoc)
    return doc_query.all()
| | | |
| | | def get_text_with_entities(self, entity_names: list[str]) -> str: |
| | | def get_texts_with_entities(self, entity_names: list[str]): |
| | | """ |
| | | 根据实体词获取文本内容 |
| | | 根据实体词获取文本内容列表 |
| | | :param entity_names: list[str] - 实体词 |
| | | :return: str - 文本 |
| | | :return: list[str] - 文本列表 |
| | | """ |
| | | if not entity_names: |
| | | return "" |
| | |
| | | _entitie_ids = [entity.id for entity in _entities] |
| | | links = self.session.query(TParagraphEntityLink).where(TParagraphEntityLink.entity_id.in_(_entitie_ids)).all() |
| | | _paragraphs = [link.paragraph for link in links] |
| | | return [self.get_paragraph_full_text(p) for p in _paragraphs] |
def get_text_with_entities(self, entity_names: list[str]) -> str:
    """
    Return the texts found for the given entity names as one string.

    :param entity_names: list[str] - entity names
    :return: str - results of get_texts_with_entities joined by newlines
    """
    return '\n'.join(self.get_texts_with_entities(entity_names))
| | | |
| | | return '\n'.join([self.get_paragraph_full_text(p) for p in _paragraphs]) |
def get_entities_by_names(self, names: list[str]):
    """
    Query the entities whose name is one of the given names.

    :param names: list[str] - entity names to look up
    :return: list of TEntity rows whose name is contained in ``names``
    """
    entity_query = self.session.query(TEntity)
    return entity_query.where(TEntity.name.in_(names)).all()
| | | |
def get_paragraph_full_text(self, p: TParagraph):
    """
    Render a paragraph together with all of its descendants as text.

    A paragraph whose title_level is 0 contributes only its text; any other
    paragraph is prefixed with its title_num and a space. A newline follows,
    then the renderings of the children joined by newlines.

    :param p: TParagraph - paragraph at the root of the subtree to render
    :return: str - the paragraph line followed by its children's text
    """
    if p.title_level == 0:
        heading = p.text
    else:
        heading = p.title_num + ' ' + p.text
    # A distinct loop name avoids shadowing the parameter ``p``.
    child_texts = [self.get_paragraph_full_text(child) for child in p.children]
    return heading + '\n' + '\n'.join(child_texts)
| | | |
def get_entities_by_doc_type(self, doc_type):
    """
    Query the entities that belong to one document type.

    :param doc_type: value compared against TEntity.doc_type
    :return: list of TEntity rows whose doc_type equals ``doc_type``
    """
    entity_query = self.session.query(TEntity)
    return entity_query.where(TEntity.doc_type == doc_type).all()
| | | |
def commit(self):
    """Call commit() on the helper's session."""
    session = self.session
    session.commit()
| | | |
| | | |
# DocDbHelper instance bound to the name doc_dbh.
# NOTE(review): presumably a module-level shared instance — confirm, since
# the original indentation is not visible here.
doc_dbh = DocDbHelper()
| | | |
| | | # if __name__ == '__main__': |
| | | # text = doc_dbh.get_text_with_entities(['遥控包格式']) |
| | | # print(text) |
| | | # doc_db = DocDbHelper() |
| | | # # doc_db.insert_entities() |
| | | # doc = doc_db.add_doc(DocInfo(file='aaa', file_name='test')) |
| | | # p1 = doc_db.add_paragraph(doc.id, None, ParagraphInfo(text='test1', title_level=1, num=1, num_level=1)) |
| | | # p2 = doc_db.add_paragraph(doc.id, p1.id, ParagraphInfo(text='test2', title_level=2, num=1, num_level=2)) |
| | | # p3 = doc_db.add_paragraph(doc.id, p2.id, ParagraphInfo(text='test3', title_level=3, num=1, num_level=3)) |
| | | # doc_db.add_paragraph_ref_link(TParagraphRefLink(parent_id=p1.id, child_id=p3.id)) |