# Spaces: Running
| import gradio as gr | |
| from transformers import GPT2LMHeadModel | |
| from indobenchmark import IndoNLGTokenizer | |
# Model assets are loaded once, at import time, and shared by every request.
gpt_tokenizer = IndoNLGTokenizer.from_pretrained("indobenchmark/indogpt")

# Reuse the end-of-sequence token as the padding token.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

# GPT-2 language-model head loaded from the "abdiharyadi/kancilgpt" checkpoint.
kancilgpt = GPT2LMHeadModel.from_pretrained("abdiharyadi/kancilgpt")
# Prompt that opens the first part of a story; the model continues after "judul:".
_INITIAL_PROMPT = "<s> awal cerita | judul:"

# A finished part must end with one of these markers.
_VALID_ENDINGS = ("bersambung", "tamat")


def _title_from_segment(segment):
    """Return the words of ``segment`` after its first word (the ``judul:`` label)."""
    _, *words = segment.split()
    return " ".join(words)


def _format_story(judul, isi, suffix):
    """Render the title, a dashed underline of the same length, the body and a suffix."""
    return judul + "\n" + ("-" * len(judul)) + "\n" + isi + suffix


def _last_quote_if_unbalanced(isi):
    """Return the index of the last '"' when their count is odd, else ``None``."""
    if isi.count('"') % 2 == 0:
        return None
    return isi.rfind('"')


def generate_story():
    """Generate the opening part of a story, yielding the text as it grows.

    The model is sampled two new tokens at a time.  After every step the
    decoded text is split on ``" | "`` into the prompt header, the
    ``judul: ...`` segment, the story body (``isi``) and an ending marker, and
    a snapshot of what has been produced so far is yielded.

    Generation is retried, with no upper bound on attempts, when:

    * the title contains a ``"."`` (restart from the initial prompt);
    * the end-of-sequence token appears before an ending segment has started
      (the last step is sampled again from the same prompt);
    * the finished body has an odd number of ``"`` characters (the body is cut
      just before the last quote and generation resumes from there);
    * the ending does not start with ``"bersambung"`` or ``"tamat"`` (the
      ending is regenerated for the same title and body).

    Yields:
        str: ``"..."`` first, then ``"<judul>..."`` while the title is being
        written, then the title, a dashed underline and the body followed by
        ``"..."``, and finally the same layout followed by ``"\\n\\ntamat."``.

    Raises:
        ValueError: if the decoded text contains no ``" | "`` separator, since
            the title-stage unpacking expects exactly two segments.
    """
    prompt = _INITIAL_PROMPT
    judul = ""
    isi = ""
    end_part = ""
    isi_not_checked = True

    yield "..."

    while True:
        # Extend the prompt two tokens at a time until the model emits EOS
        # after an ending segment has been started.
        while True:
            gpt_input = gpt_tokenizer(prompt, return_tensors="pt")
            gpt_out = kancilgpt.generate(
                **gpt_input,
                do_sample=True,
                max_new_tokens=2,
                pad_token_id=gpt_tokenizer.eos_token_id,
                eos_token_id=gpt_tokenizer.eos_token_id,
            )[0]
            ended = bool(gpt_out[-1] == gpt_tokenizer.eos_token_id)

            result = gpt_tokenizer.decode(gpt_out)
            segments = result.split(" | ")

            if len(segments) <= 2:
                # Only the title has been written so far.
                _, judul_prompt = segments
                judul = _title_from_segment(judul_prompt)
                yield judul + "..."

                if "." in judul:
                    print("Invalid judul!")
                    prompt = _INITIAL_PROMPT
                    continue

                isi = ""
                end_part = ""
                if ended:
                    # EOS before any body text: sample this step again.
                    continue
            else:
                _, judul_prompt, isi, *ending_segments = segments
                end_part = "".join(ending_segments)
                judul = _title_from_segment(judul_prompt)
                yield _format_story(judul, isi, "...")

                if len(segments) == 3:
                    if ended:
                        # EOS before an ending segment: sample this step again.
                        continue
                elif isi_not_checked:
                    # The body is complete once an ending segment exists;
                    # validate its quotes a single time.
                    cut = _last_quote_if_unbalanced(isi)
                    if cut is not None:
                        print("Invalid isi!")
                        trimmed_isi = isi[:cut].rstrip()
                        prompt = f"<s> awal cerita | judul: {judul} | {trimmed_isi}"
                        continue
                    isi_not_checked = False

            if ended:
                break
            prompt = result

        if end_part.startswith(_VALID_ENDINGS):
            break

        print("Invalid ending! Regenerating ....")
        prompt = f"<s> awal cerita | judul: {judul} | {isi} |"

    total_isi = isi

    print("We skip the rest of the part for debug.")

    # TODO: Solve this. Continuation of stories whose ending is "bersambung"
    # is not implemented yet. The earlier draft looped until the ending
    # started with "tamat": it took the tail of `isi` (starting at a sentence
    # boundary outside quotes, within the last 1750 characters), re-prompted
    # with f"<s> pertengahan cerita | judul: {judul} | {next_isi}" using
    # max_length=512, rejected outputs with "</s>" or "|" in the body, an
    # invalid ending or unbalanced quotes, and appended the newly generated
    # text to `total_isi`.

    # The story is always closed with "tamat." because continuation is skipped.
    yield _format_story(judul, total_isi, "\n\ntamat.")
# Single text area that displays the story; generate_story takes no inputs
# and yields successive snapshots of the text.
story_box = gr.Textbox(label="cerita", lines=7)

demo = gr.Interface(fn=generate_story, inputs=None, outputs=[story_box])

demo.launch()