botsi committed
Commit 389922e · verified · 1 Parent(s): 1eed4e1

Update app.py

Files changed (1):
  1. app.py  +6 -6
app.py CHANGED
@@ -230,19 +230,19 @@ def generate(
 
     # Construct the input prompt using the functions from the construct_input_prompt function
     input_prompt = construct_input_prompt(chat_history, message)
-
+
     # Move the condition here after the assignment
     if input_prompt:
         conversation.append({"role": "system", "content": input_prompt})
+
+    # Convert input prompt to tensor
+    input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
+
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    # Convert input prompt to tensor
-    # input_ids = tokenizer(input_prompt, return_tensors="pt").to(model.device)
-    # Original from HuggingFace Llama2 Chatbot: input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
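
For reference, below is a minimal sketch of the conversation-to-token-ids flow this hunk converges on, assuming a Hugging Face tokenizer whose checkpoint ships a chat template. The checkpoint name and the MAX_INPUT_TOKEN_LENGTH value are placeholder assumptions, not taken from the Space's app.py. The sketch also makes visible that the new tokenizer(input_prompt, ...) result is short-lived: apply_chat_template reassigns input_ids a few lines later, and in this version the final tensor is no longer moved with .to(model.device).

# Minimal, self-contained sketch; checkpoint and token limit are
# hypothetical placeholders, not the Space's actual configuration.
from transformers import AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value for illustration

# Placeholder checkpoint (gated on the Hub); any chat-templated tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

def build_input_ids(system_prompt, chat_history, message):
    # Assemble the conversation in the role/content format the template expects.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    # With return_tensors="pt", apply_chat_template returns a (1, seq_len)
    # tensor of token ids for the fully templated conversation; this is the
    # assignment that overwrites any earlier input_ids in the new app.py.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")

    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens, mirroring the
    # trimming branch at the end of the hunk.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    return input_ids

# Example usage:
# ids = build_input_ids("You are a helpful assistant.",
#                       [("Hi", "Hello! How can I help?")],
#                       "Summarize our chat.")
# ids.shape  ->  torch.Size([1, seq_len])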