Spaces:
Runtime error
Runtime error
| import contextlib | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import utils | |
| from kb import KB | |
| import wikipedia | |
# Cap on how many comma-separated search terms are looked up per search;
# wiki_add_text also compares the number of collected texts against it.
MAX_TOPICS = 5
# Maximum number of suggestion buttons laid out side by side in one row.
BUTTON_COLUMS = 4

st.header("Extracting a Knowledge Graph from text")
# Loading the model
def load_model():
    """Load the tokenizer and seq2seq model of the ``Babelscape/rebel-large`` checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple, both created with ``from_pretrained``.
    """
    checkpoint = "Babelscape/rebel-large"
    # The tokenizer is created first, then the model, as a tuple in that order.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def generate_kb():
    """Build a knowledge base from the collected wiki texts and publish it to the session.

    Loads the model, joins every entry of ``st.session_state['wiki_text']``
    with single spaces, passes the result to ``utils.from_text_to_kb`` and
    writes the resulting network to ``networks/network.html``. The chart path
    and the KB's textual representation are stored in the session state and
    ``error_url`` is cleared.
    """
    status_line = st.text('Loading NER model... It may take a while.')
    tokenizer, model = load_model()
    st.success('Model loaded!')
    # Blank out the "loading" placeholder now that the model is available.
    status_line.text("")

    joined_text = ' '.join(st.session_state['wiki_text'])
    network_path = "networks/network.html"
    kb = utils.from_text_to_kb(joined_text, model, tokenizer, "", verbose=True)
    utils.save_network_html(kb, filename=network_path)

    st.session_state.kb_chart = network_path
    st.session_state.kb_text = kb.get_textual_representation()
    st.session_state.error_url = None
def show_textbox():
    """Render every stored wiki text inside its own expander.

    Each expander is labelled with the first 30 characters of its text
    followed by ``...``; only the first expander starts out expanded.
    """
    texts = st.session_state['wiki_text']
    if not texts:
        return
    for position, summary in enumerate(texts):
        with st.expander(label=f"{summary[:30]}...", expanded=(position == 0)):
            st.markdown(summary)
def wiki_show_text(page_title):
    """Fetch the summary of a suggested wiki page and add it to the session.

    On success the page summary is appended to ``wiki_text``, the lower-cased
    title to ``topics``, the title is dropped from ``wiki_suggestions`` and the
    collected texts are rendered. If the title is ambiguous, it is replaced in
    the suggestions by up to three of the disambiguation options. Any other
    ``wikipedia.WikipediaException`` simply drops the title.

    Args:
        page_title: Exact title to look up (auto-suggest is disabled).
    """

    def _forget_suggestion():
        # list.remove raises ValueError when the value is absent (e.g. a
        # stale button clicked after the list changed), so check first.
        current = st.session_state['wiki_suggestions']
        if page_title in current:
            current.remove(page_title)

    with st.spinner(text="Fetching wiki page..."):
        # print(st.session_state['wiki_suggestions'])
        try:
            # Only the calls that talk to Wikipedia live inside the try.
            page = wikipedia.page(title=page_title, auto_suggest=False)
            summary = page.summary
        except wikipedia.DisambiguationError as e:
            # Must precede the generic handler: it is the more specific case.
            with st.spinner(text="Woops, ambigious term, recalculating options..."):
                _forget_suggestion()
                candidates = st.session_state['wiki_suggestions'] + e.options[:3]
                # dict.fromkeys de-duplicates while keeping the existing
                # suggestion order (a set would shuffle the buttons).
                st.session_state['wiki_suggestions'] = list(dict.fromkeys(candidates))
                show_textbox()
        except wikipedia.WikipediaException:
            _forget_suggestion()
        else:
            st.session_state['wiki_text'].append(summary)
            st.session_state['topics'].append(page_title.lower())
            _forget_suggestion()
            show_textbox()
def wiki_add_text(term):
    """Add the wiki summary of ``term`` to the collected texts.

    Does nothing once ``MAX_TOPICS`` texts have been collected. On success the
    summary is appended to ``wiki_text``, the lower-cased term to ``topics``
    and the term is dropped from ``nodes``. If the term is ambiguous, it is
    replaced in ``nodes`` by up to three of the disambiguation options. Any
    other ``wikipedia.WikipediaException`` simply drops the term.

    Args:
        term: Exact page title to look up (auto-suggest is disabled).
    """
    # `>=` keeps the total at MAX_TOPICS; `>` would let one extra text in.
    if len(st.session_state['wiki_text']) >= MAX_TOPICS:
        return

    def _forget_node():
        # list.remove raises ValueError when the value is absent, so check first.
        nodes = st.session_state['nodes']
        if term in nodes:
            nodes.remove(term)

    try:
        # Only the calls that talk to Wikipedia live inside the try.
        page = wikipedia.page(title=term, auto_suggest=False)
        extra_text = page.summary
    except wikipedia.DisambiguationError as e:
        # Must precede the generic handler: it is the more specific case.
        with st.spinner(text="Woops, ambigious term, recalculating options..."):
            _forget_node()
            candidates = st.session_state['nodes'] + e.options[:3]
            # dict.fromkeys de-duplicates while preserving order.
            st.session_state['nodes'] = list(dict.fromkeys(candidates))
    except wikipedia.WikipediaException:
        _forget_node()
    else:
        st.session_state['wiki_text'].append(extra_text)
        st.session_state['topics'].append(term.lower())
        _forget_node()
def reset_thread():
    """Reset the wiki-related session state to its empty values.

    The list-valued entries become fresh empty lists, ``has_run_wiki`` is set
    to ``False`` and ``html_wiki`` to the empty string.
    """
    state = st.session_state
    for list_key in ('wiki_text', 'topics', 'nodes', 'wiki_suggestions'):
        # A new list per key so the entries never share one object.
        state[list_key] = []
    state['has_run_wiki'] = False
    state['html_wiki'] = ''
def show_wiki_hub_page():
    """Lay out the search box and the Search / Generate KB / Reset buttons.

    The search input and the Search button both trigger
    ``wiki_show_suggestion``; the other two buttons trigger ``generate_kb``
    and ``reset_thread`` respectively.
    """
    # Both column rows are created up front, in the same order as before.
    search_col, search_button_col = st.columns([7, 1])
    generate_col, reset_col, _spacer_col = st.columns([2, 1.2, 8])

    with search_col:
        st.text_input("Search", on_change=wiki_show_suggestion, key="text", value="graphs, are, awesome")

    with search_button_col:
        # Two empty text elements sit above the button
        # (presumably to align it vertically with the input box).
        st.text('')
        st.text('')
        st.button("Search", on_click=wiki_show_suggestion, key="show_suggestion_key")

    with generate_col:
        st.button("Generate KB", on_click=generate_kb)

    with reset_col:
        st.button("Reset", on_click=reset_thread)
def wiki_show_suggestion():
    """Search Wikipedia for each comma-separated term and show the hits as buttons.

    Reads the search box (``st.session_state.text``), takes at most
    ``MAX_TOPICS`` comma-separated terms and asks ``wikipedia.search`` for up
    to three results per term. New results are appended to
    ``st.session_state['wiki_suggestions']`` and the suggestion buttons are
    rendered. Nothing happens when the search box is empty.
    """
    with st.spinner(text="Fetching wiki topics..."):
        text = st.session_state.text
        if not text:
            return

        suggestions = st.session_state['wiki_suggestions']
        for subj in text.split(",")[:MAX_TOPICS]:
            query = subj.strip()
            # "a,,b" or a trailing comma yields blank terms; skip them rather
            # than sending an empty query to the search API.
            if not query:
                continue
            for hit in wikipedia.search(query, results=3):
                # Keep each title once so repeated searches do not pile up
                # duplicate buttons.
                if hit not in suggestions:
                    suggestions.append(hit)
        # Write the list back, mirroring the original `+=` on the state entry.
        st.session_state['wiki_suggestions'] = suggestions
        show_wiki_suggestions_buttons()
def show_wiki_suggestions_buttons():
    """Render one button per wiki suggestion in a grid of at most ``BUTTON_COLUMS`` columns.

    Suggestions fill the grid row by row. Clicking a button calls
    ``wiki_show_text`` with that suggestion's title. Nothing is rendered when
    there are no suggestions.
    """
    suggestions = st.session_state['wiki_suggestions']
    if not suggestions:
        return

    # st.session_state['wiki_suggestions'] = list(set(st.session_state['wiki_suggestions']))
    num_cols = min(len(suggestions), BUTTON_COLUMS)
    columns = st.columns([1] * num_cols)

    for index, title in enumerate(suggestions):
        # index % num_cols reproduces the row-by-row fill of the grid.
        with columns[index % num_cols]:
            # Retained as a best-effort guard so one failing button does not
            # abort the rest of the grid.
            with contextlib.suppress(Exception):
                # The key uses the position in the whole list (not just the
                # column index), so equal titles in the same column on
                # different rows no longer collide on the widget key.
                st.button(title, on_click=wiki_show_text, args=(title,), key=f"{index}{title}wiki_suggestion")
def init_variables():
    """Seed the wiki-related session state the first time the script runs.

    The presence of ``'wiki_suggestions'`` is the marker for "already
    initialised"; when it is missing, every entry below is (re)assigned.
    """
    if 'wiki_suggestions' in st.session_state:
        return

    defaults = {
        'wiki_text': [],
        'topics': [],
        'nodes': [],
        'has_run_wiki': True,
        'wiki_suggestions': [],
        'html_wiki': '',
    }
    for key, value in defaults.items():
        st.session_state[key] = value
# Initialise the session state, then draw the search controls.
init_variables()
show_wiki_hub_page()

# kb chart session state
if 'kb_chart' not in st.session_state:
    st.session_state.kb_chart = None
if 'kb_text' not in st.session_state:
    st.session_state.kb_text = None
if 'error_url' not in st.session_state:
    st.session_state.error_url = None

# show graph
if st.session_state.error_url:
    st.markdown(st.session_state.error_url)
elif st.session_state.kb_chart:
    with st.container():
        st.subheader("Generated KB")
        st.markdown("*You can interact with the graph and zoom.*")
        # Read the saved network page inside a context manager so the file
        # handle is closed promptly instead of being left to the garbage collector.
        with open(st.session_state.kb_chart, 'r', encoding='utf-8') as html_file:
            html_source_code = html_file.read()
        components.html(html_source_code, width=700, height=700)
        st.markdown(st.session_state.kb_text)