KYO30 committed on
Commit
67bb651
·
verified ·
1 Parent(s): 88adda9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -4,26 +4,30 @@ import torch
4
  from threading import Thread
5
 
6
  # --- 1. ๋ชจ๋ธ ๋กœ๋“œ (Space์˜ GPU ํ™œ์šฉ) ---
7
- # ์š”์ฒญํ•˜์‹  ๋ชจ๋ธ ์ด๋ฆ„์ž…๋‹ˆ๋‹ค.
8
- MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-2505"
9
 
10
  print(f"๋ชจ๋ธ์„ ๋กœ๋”ฉ ์ค‘์ž…๋‹ˆ๋‹ค: {MODEL_NAME} (Space GPU ์‚ฌ์šฉ)")
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_NAME,
14
- dtype=torch.float16, # ๐Ÿ’ฅ ์ˆ˜์ •: 'torch_dtype' ๋Œ€์‹  'dtype' ์‚ฌ์šฉ
15
- device_map="auto"
16
  )
17
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
18
 
19
  # --- 2. ์ฑ—๋ด‡ ์‘๋‹ต ํ•จ์ˆ˜ (Gradio๊ฐ€ ์ด ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœ) ---
 
 
20
  def predict(message, history):
21
 
22
  # Kanana์˜ ํ”„๋กฌํ”„ํŠธ ํ˜•์‹: <bos>user\n{prompt}\n<eos>assistant\n
23
  history_prompt = ""
 
24
  for user_msg, assistant_msg in history:
25
  history_prompt += f"<bos>user\n{user_msg}\n<eos>assistant\n{assistant_msg}\n"
26
 
 
27
  final_prompt = history_prompt + f"<bos>user\n{message}\n<eos>assistant\n"
28
 
29
  inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
@@ -31,35 +35,37 @@ def predict(message, history):
31
  # --- ์‹ค์‹œ๊ฐ„ ํƒ€์ดํ•‘ ํšจ๊ณผ(์ŠคํŠธ๋ฆฌ๋ฐ)๋ฅผ ์œ„ํ•œ ์„ค์ • ---
32
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
33
 
 
34
  generation_kwargs = dict(
35
- inputs,
36
  streamer=streamer,
37
- max_new_tokens=1024,
38
  eos_token_id=tokenizer.eos_token_id,
39
  pad_token_id=tokenizer.pad_token_id,
40
- temperature=0.7,
41
- do_sample=True
42
  )
43
 
44
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
45
  thread.start()
46
 
 
47
  generated_text = ""
48
  for new_text in streamer:
49
  generated_text += new_text
50
- yield generated_text
51
 
52
  # --- 3. Gradio ์ฑ—๋ด‡ UI ์ƒ์„ฑ ---
53
- # ๐Ÿ’ฅ ์ˆ˜์ •: ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•œ 'retry_btn'๊ณผ 'undo_btn' ์ธ์ž๋ฅผ ์ œ๊ฑฐํ–ˆ์Šต๋‹ˆ๋‹ค.
54
  chatbot_ui = gr.ChatInterface(
55
  fn=predict, # ์ฑ—๋ด‡์ด ์‚ฌ์šฉํ•  ํ•จ์ˆ˜
56
  title="Kanana 1.5 ์ฑ—๋ด‡ ํ…Œ์ŠคํŠธ ๐Ÿค–",
57
  description=f"{MODEL_NAME} ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.",
58
- theme="soft",
59
- examples=[["ํ•œ๊ตญ์˜ ์ˆ˜๋„๋Š” ์–ด๋””์•ผ?"], ["AI์— ๋Œ€ํ•ด 3์ค„๋กœ ์š”์•ฝํ•ด์ค˜."]],
60
- # retry_btn=None, <-- ์ด ๋ถ€๋ถ„์ด ์˜ค๋ฅ˜ ์›์ธ (์ œ๊ฑฐ)
61
- # undo_btn="์ด์ „ ๋Œ€ํ™” ์‚ญ์ œ", <-- ์ด ๋ถ€๋ถ„๋„ ์ตœ์‹  ๋ฒ„์ „์—์„  ์ด๋ฆ„์ด ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์–ด ์ œ๊ฑฐ
62
- clear_btn="์ „์ฒด ๋Œ€ํ™” ์ดˆ๊ธฐํ™”" # 'clear_btn'์€ ์•„์ง ์œ ํšจํ•ฉ๋‹ˆ๋‹ค.
63
  )
64
 
65
- # ---
 
 
 
from threading import Thread

# --- 1. Model load (uses the Space's GPU) ---
# NOTE(review): this commit switched the checkpoint to "...-2405" on the claim
# that the 2505 release "does not exist yet", but it is the other way around:
# kakaocorp/kanana-1.5-2.1b-instruct-2505 is the published Kanana 1.5 instruct
# checkpoint on the Hugging Face Hub, and no "-2405" variant of Kanana 1.5 is
# listed there, so from_pretrained("...-2405") raises a repository-not-found
# error at startup. Restored the 2505 id — confirm against the Hub model card.
MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-2505"

print(f"๋ชจ๋ธ์„ ๋กœ๋”ฉ ์ค‘์ž…๋‹ˆ๋‹ค: {MODEL_NAME} (Space GPU ์‚ฌ์šฉ)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # halve the weight memory for inference
    device_map="auto",          # let accelerate place the weights on the GPU
)
print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
# --- 2. Chat response function (called by Gradio for every user turn) ---
def predict(message, history):
    """Stream the model's reply to *message*, given prior *history* turns.

    history is the (user, assistant) pair list that gr.ChatInterface
    maintains; yields the accumulated reply so the UI repaints it live.
    """
    # Fold earlier turns into Kanana's prompt format:
    #   <bos>user\n{prompt}\n<eos>assistant\n
    past_turns = [
        f"<bos>user\n{user_msg}\n<eos>assistant\n{assistant_msg}\n"
        for user_msg, assistant_msg in history
    ]
    # Append the current user message and open the assistant slot.
    final_prompt = "".join(past_turns) + f"<bos>user\n{message}\n<eos>assistant\n"

    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)

    # Streamer hands back decoded text piece by piece as tokens are generated.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # model.generate blocks, so it runs on a worker thread while this
    # generator consumes the streamer.
    generation_kwargs = dict(
        **inputs,  # forward every key of the tokenizer output (input_ids, ...)
        streamer=streamer,
        max_new_tokens=1024,            # cap on generated tokens
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        temperature=0.7,                # mild randomness
        do_sample=True,                 # enable sampling
    )
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    # Re-yield the growing reply so Gradio shows a typing effect.
    generated_text = ""
    for piece in streamer:
        generated_text += piece
        yield generated_text
# --- 3. Build the Gradio chat UI ---
# gr.ChatInterface wraps `predict` in a ready-made chatbot front end.
chatbot_ui = gr.ChatInterface(
    fn=predict,  # generator function Gradio streams from
    title="Kanana 1.5 ์ฑ—๋ด‡ ํ…Œ์ŠคํŠธ ๐Ÿค–",
    description=f"{MODEL_NAME} ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.",
    theme="soft",
    examples=[["ํ•œ๊ตญ์˜ ์ˆ˜๋„๋Š” ์–ด๋””์•ผ?"], ["AI์— ๋Œ€ํ•ด 3์ค„๋กœ ์š”์•ฝํ•ด์ค˜."]],
    # retry_btn / undo_btn / clear_btn were dropped: this Gradio version
    # no longer accepts those keyword arguments.
)

# --- 4. Run the app ---
# .launch() serves the interface on the Space.
chatbot_ui.launch()