from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import gradio as gr
import re
import time


# load the tokenizer and the base model, then apply the PEFT adapter
# weights published in the same repository
tokenizer = AutoTokenizer.from_pretrained("osiria/primo")
model = AutoModelForCausalLM.from_pretrained("osiria/primo")
model = PeftModel.from_pretrained(model, "osiria/primo")


class Prime:
    """Wrapper around the tokenizer and model implementing a custom
    token-by-token generation loop with repetition control and light
    post-processing of the decoded text."""

    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model

    def _check_sublist(self, lst, sub_lst, sep=" "):
        # True if sub_lst occurs as a contiguous subsequence of lst
        # (approximate substring check on the string serializations)
        lst = sep.join(map(str, lst))
        sub_lst = sep.join(map(str, sub_lst))
        return sub_lst in lst

    def _exclude_sublist(self, lst, sub_lst, sep=" "):
        # remove every occurrence of sub_lst from lst, then map the
        # surviving items back to their original type
        l_type = type(lst[0])
        lst = sep.join(map(str, lst))
        sub_lst = sep.join(map(str, sub_lst))
        lst = re.sub(r"\s+", " ", lst.replace(sub_lst, "")).strip().split(sep)
        return list(map(l_type, lst))

    def generate(self, prompt, message="", sep=" [AI]", max_tokens=100, excluded=[[40, 19]],
                 lookback=5, resample_tokens=[27793], replace_tokens={11302: 23318},
                 stop_tokens=[239],
                 sample=False,
                 top_k=5):
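        """Generate a reply to `prompt`, decoding one token at a time.

        `message` is prepended to the prompt as a system-style instruction and
        `sep` marks the start of the model's turn. The token-id arguments are
        model-specific: `excluded` sequences are stripped from the final
        output, `resample_tokens` always fall back to the runner-up candidate,
        `replace_tokens` maps ids to substitutes, and `stop_tokens` end
        generation. `lookback` sets the n-gram window used for repetition
        detection; `sample` and `top_k` switch between greedy decoding and
        top-k sampling."""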

        if message:
            prompt = message + ". " + prompt
        # normalize curly quotes and apostrophes to their ASCII equivalents
        prompt = prompt.replace("“", '"').replace("”", '"').replace("’", "'")
        if not sample:
            top_k = 2
        tokens = self.tokenizer.encode("[HUMAN] " + prompt + sep)
        tokens_generated = []
        checkpoint = 0
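        # decode autoregressively until a stop token is produced or max_tokens
        # new tokens have been generated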
        while tokens[-1] not in stop_tokens and len(tokens_generated) < max_tokens:
            # score the next token; no_grad avoids building autograd graphs
            # during inference
            with torch.no_grad():
                output = self.model(input_ids=torch.tensor([tokens]).to(device)).logits[0, -1]
            output = torch.softmax(output, dim=0)
            candidates = torch.topk(output, k=top_k)
            if sample:
                # sample among the top_k candidates, proportionally to their
                # probabilities
                indices = candidates.indices
                scores = candidates.values
                next_token = indices[torch.multinomial(scores, 1)[0].item()]
            else:
                next_token = candidates.indices[0]
            next_token = next_token.item()
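            # repetition control: if the candidate token echoes one of the two
            # most recent tokens, or would complete an n-gram already present
            # in the generated text, fall back to the runner-up candidate; if
            # a repeated n-gram shows up after a sentence boundary was seen,
            # truncate there and stop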
            sub_tokens = tokens_generated[-lookback:] + [next_token]
            if next_token in resample_tokens:
                next_token = candidates.indices[1]
                next_token = next_token.item()
            if len(tokens_generated) >= (lookback + 1) and next_token in tokens_generated[-2:]:
                next_token = candidates.indices[1]
                next_token = next_token.item()
            elif len(tokens_generated) >= lookback and self._check_sublist(tokens_generated, sub_tokens):
                if checkpoint:
                    tokens = tokens[:checkpoint]
                    break
                else:
                    next_token = candidates.indices[1]
                    next_token = next_token.item()
                    sample = True
            if next_token in replace_tokens:
                next_token = replace_tokens[next_token]
            tokens = tokens + [next_token]
            tokens_generated = tokens_generated + [next_token]
            if next_token == 5:
                # token id 5 appears to mark a sentence boundary; remember the
                # position so the output can be truncated here if a loop is
                # detected
                checkpoint = len(tokens)
        # strip excluded token sequences and decode
        for ex_lst in excluded:
            tokens = self._exclude_sublist(tokens, ex_lst)
        output = self.tokenizer.decode(tokens, skip_special_tokens=True)
        output = output.split(sep)[-1].strip()
        output = output[0].upper() + output[1:]
        if output[-1] == self.tokenizer.decode(stop_tokens[0]):
            output = output[:-1]
        # render numbered enumerations as HTML bullet lists
        if len(re.findall(r"\d\.", output)) > 1:
            output = re.sub(r"\d\.", "<br>•", output)
            output = re.sub(r"^<br>", "", output)
        return output


# move the model to GPU when available and switch to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
prime = Prime(tokenizer=tokenizer, model=model)

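# illustrative usage sketch (the prompt here is only an example):
# reply = prime.generate("Scrivi una breve poesia sul mare", max_tokens=50)
# print(reply)
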
def process_input(user_input, max_tokens, sample, top_k, message):
    # adapter between the Gradio callbacks and Prime.generate
    return prime.generate(prompt=user_input, message=message,
                          max_tokens=max_tokens, sample=sample,
                          top_k=top_k)


header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
    writing-mode: vertical-lr;
    text-orientation: upright;
    background-color: red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;">β</span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;">ββ</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">βββββ</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">βββββ</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;">ββ</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;">β</span>
</body>
</center>
<br>
<center><img src="file/primo.png" width="100"></center>
'''


with gr.Blocks(title="primo", css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="md", spacing_size="md")) as interface:
    gr.Markdown(header)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<b>options</b>")
            max_tokens = gr.Slider(1, 250, value=150, label="maximum number of tokens", info="choose a limit between 1 and 250")
            sample = gr.Checkbox(label="sampling")
            top_k = gr.Slider(1, 5, step=1, value=1, label="creativity", info="choose a level between 1 and 5")
            message = gr.Textbox(label="system message", value="")
            clear = gr.Button("clear conversation")
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(label="prime").style(height=600)
            msg = gr.Textbox(label="prompt")

            def user(user_message, history):
                # append the user's turn and lock the textbox while the bot replies
                return gr.update(value="", interactive=False), history + [[user_message, None]]

            def bot(history, message, max_tokens, sample, top_k):
                bot_message = process_input(history[-1][0], message=message, max_tokens=max_tokens,
                                            sample=sample, top_k=top_k)
                history[-1][1] = ""
                # stream the reply character by character for a typing effect
                for character in bot_message:
                    history[-1][1] += character
                    time.sleep(0.05)
                    yield history

            response = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                bot, [chatbot, message, max_tokens, sample, top_k], chatbot
            )
            response.then(lambda: gr.update(interactive=True), None, [msg], queue=False)
            clear.click(lambda: None, None, chatbot, queue=False)
        with gr.Column(scale=1):
            gr.Markdown("<b>warning</b>")
            gr.Markdown("the model may behave unexpectedly when given prompts that are too far from its pre-training or fine-tuning and, due to the probabilistic nature of the generation mechanism, it may occasionally produce distorted or offensive content with respect to topics such as gender, ethnicity, ideologies, and political or religious beliefs<br><br>because of these limitations, the model and its outputs should be used with caution, and should not be relied on in contexts that require the generated text to be correct or truthful")


interface.queue()
interface.launch()