Spaces:
Runtime error
Runtime error
| import wikipedia | |
class KB:
    """In-memory knowledge base of entities, relations and their source articles.

    Relation endpoints are normalized to Wikipedia page titles via
    ``get_wikipedia_data``; a relation is only stored when both of its
    endpoints resolve to a page.
    """

    def __init__(self):
        # { entity_title: {"url": ..., "summary": ...} }
        self.entities = {}
        # [ {"head": entity_title, "type": ..., "tail": entity_title,
        #    "meta": { article_url: {"spans": [...]} }} ]
        self.relations = []
        # { article_url: {"article_title": ..., "article_publish_date": ...} }
        self.sources = {}

    def merge_with_kb(self, kb2):
        """Merge every relation of ``kb2``, with all of its sources, into this KB.

        Every article URL recorded in a relation's ``meta`` is carried over
        together with its own entry from ``kb2.sources`` (not just the first
        one). ``kb2`` is left unmodified: relations are copied before storing.

        Raises:
            KeyError: if a relation references an article URL that is missing
                from ``kb2.sources``.
        """
        for r in kb2.relations:
            # Look the sources up before any Wikipedia query so a dangling
            # article URL fails fast.
            article_sources = {url: kb2.sources[url] for url in r["meta"]}
            resolved = self._resolve_relation(r)
            if resolved is None:
                continue
            for article_url, source_data in article_sources.items():
                self._add_source(article_url,
                                 source_data["article_title"],
                                 source_data["article_publish_date"])
            self._insert_or_merge(resolved)

    def are_relations_equal(self, r1, r2):
        """Return True if both relations share the same head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Return True if a relation equal to ``r1`` is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        """Merge the article metadata of ``r2`` into the stored equal relation.

        For each article URL in ``r2["meta"]``:
        - if the stored relation does not know the URL, its metadata is copied
          over whole;
        - if it does, only spans not already present are appended.

        ``r2``'s nested data is copied, never aliased.

        Raises:
            IndexError: if no stored relation equals ``r2``.
        """
        r1 = [r for r in self.relations if self.are_relations_equal(r2, r)][0]
        new_meta = self._copy_relation(r2)["meta"]
        for article_url, article_meta in new_meta.items():
            if article_url not in r1["meta"]:
                # Different article: take its metadata as is.
                r1["meta"][article_url] = article_meta
            else:
                # Same article: add only the spans we do not have yet.
                existing_spans = r1["meta"][article_url]["spans"]
                existing_spans += [span for span in article_meta["spans"]
                                   if span not in existing_spans]

    def get_wikipedia_data(self, candidate_entity):
        """Look up ``candidate_entity`` on Wikipedia without auto-suggestion.

        Returns:
            A dict with the page's "title", "url" and "summary", or None if
            the lookup raises any exception.
        """
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            return {"title": page.title, "url": page.url, "summary": page.summary}
        except Exception:
            # Deliberate best-effort: any failure simply means "no entity".
            return None

    def add_entity(self, e):
        """Store entity ``e`` keyed by its "title"; the other fields become the value.

        An existing entry with the same title is overwritten.
        """
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        """Add relation ``r`` extracted from an article.

        ``r`` is copied, not mutated. In the stored copy, head and tail are
        replaced by their Wikipedia titles. If either endpoint has no
        Wikipedia page, nothing is added.

        ``article_title`` / ``article_publish_date`` are recorded as the
        source of the first article URL in ``r["meta"]``, unless that URL is
        already known. If an equal relation already exists, its metadata is
        merged instead of storing a duplicate.
        """
        resolved = self._resolve_relation(r)
        if resolved is None:
            return
        article_url = next(iter(resolved["meta"]), None)
        if article_url is not None:
            self._add_source(article_url, article_title, article_publish_date)
        self._insert_or_merge(resolved)

    def get_textual_representation(self, max_summary_len=100):
        """Return a human-readable dump of entities, relations and sources.

        Entity summaries longer than ``max_summary_len`` characters are cut
        to that length and suffixed with "..."; shorter ones are shown whole.
        """
        def shorten(text):
            if len(text) <= max_summary_len:
                return text
            return f"{text[:max_summary_len]}..."

        lines = ["### Entities"]
        for title, attrs in self.entities.items():
            shown = {k: shorten(v) if k == "summary" else v for k, v in attrs.items()}
            lines.append(f"- {(title, shown)}")
        lines += ["", "### Relations"]
        lines += [f"- {r}" for r in self.relations]
        lines += ["", "### Sources"]
        lines += [f"- {s}" for s in self.sources.items()]
        return "\n".join(lines) + "\n"

    @staticmethod
    def _copy_relation(r, **overrides):
        """Return a shallow copy of ``r`` with ``overrides`` and a deep-copied "meta"."""
        import copy  # local import: only this helper needs it

        new_r = dict(r, **overrides)
        new_r["meta"] = copy.deepcopy(r["meta"])
        return new_r

    def _resolve_relation(self, r):
        """Return a copy of ``r`` whose endpoints are Wikipedia titles, or None.

        Both endpoint entities are registered in ``self.entities`` on success.
        Nothing is registered if either endpoint fails to resolve.
        """
        entities = []
        for candidate in (r["head"], r["tail"]):
            data = self.get_wikipedia_data(candidate)
            if data is None:
                # The relation is dropped anyway; skip the remaining lookup.
                return None
            entities.append(data)
        for e in entities:
            self.add_entity(e)
        return self._copy_relation(r, head=entities[0]["title"], tail=entities[1]["title"])

    def _add_source(self, article_url, article_title, article_publish_date):
        """Record source metadata for ``article_url`` unless it is already known."""
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date,
            }

    def _insert_or_merge(self, r):
        """Append ``r`` as a new relation, or merge its meta into an equal one."""
        if self.exists_relation(r):
            self.merge_relations(r)
        else:
            self.relations.append(r)