Rudran committed on
Commit 9d6426d · verified · 1 Parent(s): d478b00

Create app.py

Files changed (1)
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+ # Set page configuration
+ st.set_page_config(
+     page_title="Apertus-8B Chat",
+     page_icon="🤖",
+     layout="wide"
+ )
+
+ # Add a title to the app
+ st.title("🤖 Chat with Apertus-8B-Instruct")
+ st.caption("A Streamlit app running swiss-ai/Apertus-8B-Instruct-2509")
+
+ # --- MODEL LOADING ---
+ @st.cache_resource
+ def load_model():
+     """Loads the model and tokenizer with 4-bit quantization."""
+     model_id = "swiss-ai/Apertus-8B-Instruct-2509"
+
+     # Configure quantization to reduce memory usage
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+     # Load the model
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         quantization_config=bnb_config,
+         device_map="auto",  # Automatically maps model layers to available hardware (GPU/CPU)
+     )
+     return tokenizer, model
+
+ # Load the model and display a spinner while doing so
+ with st.spinner("Loading Apertus-8B model... This might take a moment."):
+     tokenizer, model = load_model()
+
+ # --- CHAT INTERFACE ---
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("What would you like to ask?"):
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     # Display user message in chat message container
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # --- GENERATION ---
+     with st.chat_message("assistant"):
+         with st.spinner("Thinking..."):
+             # Format the whole conversation (including history) with the model's chat template
+             input_ids = tokenizer.apply_chat_template(
+                 st.session_state.messages,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(model.device)
+             # Generate a response
+             outputs = model.generate(
+                 input_ids,
+                 max_new_tokens=256,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_k=50,
+                 top_p=0.95
+             )
+
+             # Decode only the newly generated tokens (everything after the prompt)
+             response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
+             st.markdown(response)
+
+     # Add assistant response to chat history
+     st.session_state.messages.append({"role": "assistant", "content": response})
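To try this commit locally, the environment typically needs a CUDA GPU for the 4-bit bitsandbytes load, plus `streamlit`, `torch`, `transformers`, `accelerate`, and `bitsandbytes` installed; then `streamlit run app.py`.

A natural follow-up (not part of this diff) would be streaming the reply token by token instead of blocking until `generate()` returns. Here is a minimal sketch using transformers' `TextIteratorStreamer` and Streamlit's `st.write_stream` (Streamlit ≥ 1.31); `stream_reply` is a hypothetical helper, not something in this commit:

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(tokenizer, model, messages, max_new_tokens=256):
    """Yield decoded text chunks as the model produces them (hypothetical helper)."""
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # The streamer skips the prompt tokens and decodes new tokens incrementally
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a background thread and read from the streamer
    thread = Thread(
        target=model.generate,
        kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=max_new_tokens),
    )
    thread.start()
    yield from streamer
    thread.join()

# Inside the assistant block, st.write_stream renders chunks as they arrive
# and returns the full concatenated reply:
#     response = st.write_stream(stream_reply(tokenizer, model, st.session_state.messages))
```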