botsi committed on
Commit fd7d7b0 · verified · 1 Parent(s): a0fa364

Upload app.py

Files changed (1)
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
# Original code from https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat and https://huggingface.co/spaces/radames/gradio-chatbot-read-query-param
import gradio as gr
import time
import random
import json
import mysql.connector
import os
import csv

from datetime import datetime
# from huggingface_hub import Repository, hf_hub_download

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator

# data_fetcher.py (mysql.connector is already imported above)
import urllib.parse
import urllib.request

# For Prompt Engineering
# import requests
# from huggingface_hub import AsyncInferenceClient

# Save chat history as JSON
import atexit

# Add this global variable to store the chat history
# global_chat_history = []
# Add this function to store the chat history
# def save_chat_history():
#     """Save the chat history to a JSON file."""
#     with open("chat_history.json", "w") as json_file:
#         json.dump(global_chat_history, json_file)

# from huggingface_hub import login
# HF_TOKEN = os.getenv('HF_TOKEN')

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Llama-2 7B Chat
This is your personal space to chat.
You can ask anything from strategic questions regarding the game or just chat as you like.
"""

'''LICENSE = """
<p/>

---
As a derivative work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
"""
'''

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


## gradio-chatbot-read-query-param
# JavaScript snippet that reads the session_index query parameter from the page URL
get_window_session_index = """
function() {
    const urlParams = new URLSearchParams(window.location.search);
    const session_index = urlParams.get('session_index');
    return session_index;
}
"""

def fetch_personalized_data(session_index):
    # Connect to the database.
    # NOTE: credentials are hardcoded here; moving them to environment variables
    # (e.g. os.getenv) would be safer for a public Space.
    conn = mysql.connector.connect(
        host="18.153.94.89",
        user="root",
        password="N12RXMKtKxRj",
        database="lionessdb"
    )

    # Create a cursor object
    cursor = conn.cursor()

    # Replace the placeholders with your actual database and table names
    core_table = "e5390g37096_core"
    decisions_table = "e5390g37096_decisions"

    # Fetch the relevant rows from both tables for this session_index. The table
    # names come from the variables above; session_index is passed as a bound
    # parameter rather than interpolated into the string, to avoid SQL injection.
    query = f"""
        SELECT {core_table}.playerNr,
               {core_table}.groupNr,
               {core_table}.subjectNr
        FROM {core_table}
        JOIN {decisions_table}
            ON {core_table}.playerNr = {decisions_table}.playerNr
        WHERE {decisions_table}.session_index = %s
    """

    try:
        cursor.execute(query, (session_index,))

        # Fetch all rows as a list of tuples
        rows = cursor.fetchall()

        # Convert the rows to a list of dictionaries
        data = [{'playerNr': row[0], 'groupNr': row[1], 'subjectNr': row[2]} for row in rows]
        return data

    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None

    finally:
        # Always release the cursor and connection, even on error
        cursor.close()
        conn.close()
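
# Illustrative call, assuming a matching session_index exists in the decisions
# table (the values shown are invented for illustration only):
#
#     fetch_personalized_data("abc123")
#     # -> [{'playerNr': 1, 'groupNr': 2, 'subjectNr': 3}]
#
# On a database error the function prints the error and returns None.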

## gradio-chatbot-read-query-param
def get_window_url_params():
    # Returns a JavaScript snippet that collects all URL query parameters as an object
    return """
    function() {
        const params = new URLSearchParams(window.location.search);
        const url_params = Object.fromEntries(params);
        return url_params;
    }
    """

## trust-game-llama-2-7b-chat
# app.py
def construct_input_prompt(chat_history, message):
    # Build a Llama-2 chat prompt: system prompt first, then alternating user and
    # assistant turns. Each completed assistant turn is closed with </s> before the
    # next [INST] block opens. get_default_system_prompt is defined further down at
    # module scope and defaults to personalized_data=None when called without arguments.
    input_prompt = f"<s>[INST] <<SYS>>\n{get_default_system_prompt()}\n<</SYS>>\n\n "

    for user, assistant in chat_history:
        input_prompt += f"{user} [/INST] {assistant} </s><s>[INST] "

    input_prompt += f"{message} [/INST] "

    return input_prompt
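
# For illustration, with chat_history = [("Hello", "Hi!")] and message = "What now?",
# the constructed prompt has this shape (system prompt abbreviated):
#
#     <s>[INST] <<SYS>>
#     ...system prompt...
#     <</SYS>>
#
#     Hello [/INST] Hi! </s><s>[INST] What now? [/INST]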

## trust-game-llama-2-7b-chat
# app.py
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    # system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:

    # Construct the full input prompt from the chat history and the new message
    input_prompt = construct_input_prompt(chat_history, message)

    # Use the global variable to store the chat history
    # global global_chat_history

    conversation = []

    # The fully formatted prompt doubles as the system message for the chat template.
    # (A redundant tokenizer call on input_prompt was removed here; its result was
    # immediately overwritten by apply_chat_template below.)
    if input_prompt:
        conversation.append({"role": "system", "content": input_prompt})

    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Set up the TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Set up the generation arguments
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Start the model generation thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text as chunks arrive from the streamer
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # Update the global_chat_history with the current conversation
    # global_chat_history.append({
    #     "message": message,
    #     "chat_history": chat_history,
    #     "system_prompt": input_prompt,
    #     "output": outputs[-1],  # Assuming you want to save the latest model output
    # })

# The commented-out "global_chat_history.append" block above would store the chat
# history in a global variable. The save_chat_history function would be registered
# to run when the program exits, via atexit.register(save_chat_history), and would
# save the chat history to a JSON file named "chat_history.json". The generate
# function would then append the current conversation to global_chat_history after
# generating each response.
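
# A minimal, self-contained sketch of the persistence described above, in case it
# is re-enabled. The names mirror the commented-out code; the record structure and
# indent level of the JSON output are assumptions:
#
#     global_chat_history = []
#
#     def save_chat_history():
#         """Write the accumulated chat history to chat_history.json on exit."""
#         with open("chat_history.json", "w") as json_file:
#             json.dump(global_chat_history, json_file, indent=2)
#
#     # Register once at module import so the file is written at interpreter exit
#     atexit.register(save_chat_history)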

chat_interface = gr.ChatInterface(
    fn=generate,
    theme="soft",
    retry_btn=None,
    clear_btn=None,
    undo_btn=None,
    chatbot=gr.Chatbot(avatar_images=('user.png', 'bot.png'), bubble_full_width=False),
    examples=[
        ["Can you explain the rules very briefly again?"],
        ["How much should I invest in order to win?"],
        ["What happened in the last round?"],
        ["What is my probability to win if I do not share anything?"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    # gr.Markdown(DESCRIPTION)
    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    ## gradio-chatbot-read-query-param
    url_params = gr.JSON({}, visible=False, label="URL Params")

    ## gradio-chatbot-read-query-param
    def get_session_index(history, url_params):
        # Only look up personalized data once the latest user turn is non-empty
        if history and bool(history[-1][0].strip()):
            session_index = url_params.get('session_index')
            print(session_index)
            # Fetch personalized data
            personalized_data = fetch_personalized_data(session_index)
            print(personalized_data)
            return personalized_data

    ## trust-game-llama-2-7b-chat
    # app.py
    def get_default_system_prompt(personalized_data=None):
        # personalized_data defaults to None so construct_input_prompt can call
        # this function before any personalized data has been fetched.
        # BOS, EOS = "<s>", "</s>"
        # BINST, EINST = "[INST]", "[/INST]"
        BSYS, ESYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

        DEFAULT_SYSTEM_PROMPT = f"""You are an intelligent and fair game guide in a 2-player trust game, assisting players in making decisions to win.
Answer in a consistent style. Each answer should be at most 2 sentences long. The players are called The Investor and The Dealer and keep their roles throughout the whole game.
Both start with 10€ in round 1. The game consists of 3 rounds. In round 1, The Investor invests between 0€ and 10€.
This amount is tripled automatically, and The Dealer can then distribute the tripled amount. After that, round 1 is over.
Both go into the next round with their current assets: The Investor with 10€ minus what he invested plus what he received back from The Dealer,
The Dealer with 10€ plus what he kept from the tripled amount.
You will receive a JSON with information on who trusted whom with how much money after each round as context.
Your goal is to guide players through the game, providing clear instructions and explanations.
If any question or action seems unclear, explain it rather than providing inaccurate information.
If you're unsure about an answer, it's better not to guess.

Example JSON context after a round: {personalized_data}

Few-shot training examples
{BSYS} Give an overview of the trust game. {ESYS}
{BSYS} Explain how trust amounts are calculated. {ESYS}
{BSYS} What happens if a player doesn't trust in a round? {ESYS}
"""
        print(DEFAULT_SYSTEM_PROMPT)
        return DEFAULT_SYSTEM_PROMPT
    chat_interface.render()
    # gr.Markdown(LICENSE)

if __name__ == "__main__":
    # demo.queue(max_size=20).launch()
    demo.queue(max_size=20)
    demo.launch(share=True, debug=True)

# Register the function to be called when the program exits
# atexit.register(save_chat_history)
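
# Note: url_params, get_window_session_index, and get_session_index are defined but
# never wired into the UI above. A minimal sketch of one way to connect them inside
# the gr.Blocks context, following the gradio-chatbot-read-query-param space this
# file credits. The js/_js keyword varies across Gradio versions, so treat this as
# an assumption rather than part of the committed app:
#
#     demo.load(
#         fn=lambda params: params,    # copy the parsed params into the gr.JSON component
#         inputs=[url_params],
#         outputs=[url_params],
#         js=get_window_url_params(),  # runs in the browser and returns the query params
#     )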