Update app.py
app.py
CHANGED
@@ -230,19 +230,19 @@ def generate(
 
     # Construct the input prompt using the functions from the construct_input_prompt function
     input_prompt = construct_input_prompt(chat_history, message)
-
+
     # Move the condition here after the assignment
     if input_prompt:
         conversation.append({"role": "system", "content": input_prompt})
+
+    # Convert input prompt to tensor
+    input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
+
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-
-    # input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
-    # Original from HuggingFace Llama2 Chatbot: input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
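For context, below is a minimal, self-contained sketch of the prompt-assembly and truncation flow this hunk modifies: building the conversation list, rendering it with apply_chat_template, and trimming the result to the token limit. It is not the Space's actual generate() function; the build_inputs wrapper, the construct_input_prompt stub, the checkpoint name, and the MAX_INPUT_TOKEN_LENGTH value are all assumptions for illustration.

# A minimal sketch of the flow shown in the diff, under the assumptions stated above.
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed limit; the Space defines its own constant

model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed (gated) checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

def construct_input_prompt(chat_history, message):
    # Hypothetical stand-in for the Space's helper: return a system prompt string (or "").
    return "You are a helpful assistant."

def build_inputs(message, chat_history):
    conversation = []

    # Optional system message, mirroring the `if input_prompt:` guard in the diff.
    input_prompt = construct_input_prompt(chat_history, message)
    if input_prompt:
        conversation.append({"role": "system", "content": input_prompt})

    # Replay prior turns, then append the new user message.
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user},
                             {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # apply_chat_template with return_tensors="pt" returns a [1, seq_len] tensor of token ids.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Truncate from the left so the most recent turns are kept.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    return input_ids.to(model.device)

Note that apply_chat_template(..., return_tensors="pt") already yields a token-id tensor, so the sketch only moves it to the model's device after the length check; trimming from the left keeps the newest part of the conversation when the limit is exceeded.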