Abdelrahmann121 committed on
Commit
d037f24
·
verified ·
1 Parent(s): 741a243

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +950 -0
app.py ADDED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, Wav2Vec2ForCTC, Wav2Vec2Processor
3
+ from sentence_transformers import SentenceTransformer
4
+ import numpy as np
5
+ import random
6
+ import faiss
7
+ import json
8
+ import logging
9
+ import re
10
+ import streamlit as st
11
+ from datetime import datetime
12
+ import os
13
+ import torch
14
+ import librosa
15
+ from gtts import gTTS
16
+ import tempfile
17
+ import io
18
+ import base64
19
+ import time
20
+ from audio_recorder_streamlit import audio_recorder
21
+
22
# Logging: one module-level logger shared by every component in this file;
# the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
25
+
26
+ # ============================
27
+ # AUDIO PROCESSING UTILITIES
28
+ # ============================
29
+
30
class AudioProcessor:
    """Speech-to-text (Wav2Vec2) and text-to-speech (gTTS) helpers.

    Attributes:
        stt_processor: The loaded ``Wav2Vec2Processor``, or ``None`` when
            loading failed.
        stt_model: The loaded ``Wav2Vec2ForCTC`` model, or ``None`` when
            loading failed.
    """

    # Checkpoint used for both the processor and the CTC model.
    STT_MODEL_NAME = "facebook/wav2vec2-base-960h"
    # Rate audio is resampled to on load and declared to the processor.
    SAMPLE_RATE = 16000
    # Peak absolute amplitude below which a clip is reported as silent.
    SILENCE_THRESHOLD = 0.01

    def __init__(self):
        """Load the STT processor and model; on failure leave both as None."""
        try:
            self.stt_processor = Wav2Vec2Processor.from_pretrained(self.STT_MODEL_NAME)
            self.stt_model = Wav2Vec2ForCTC.from_pretrained(self.STT_MODEL_NAME)
            logger.info("βœ… STT model loaded successfully")
        except Exception as e:
            # Best effort: the app keeps running without voice input.
            logger.error(f"❌ Error loading STT model: {e}")
            self.stt_processor = None
            self.stt_model = None

    def speech_to_text_from_bytes(self, audio_bytes):
        """Transcribe an audio clip given as raw file bytes.

        Args:
            audio_bytes: Contents of an audio file; written to a temporary
                ``.wav`` file and decoded with ``librosa.load``.

        Returns:
            The stripped transcription on success. Otherwise one of the
            status strings ``"STT model not available"``,
            ``"No speech detected. Please speak louder."``,
            ``"Could not transcribe audio"`` or
            ``"Error processing audio: <details>"``. This method does not
            raise for processing failures.
        """
        if self.stt_processor is None or self.stt_model is None:
            return "STT model not available"

        tmp_file_path = None
        try:
            # librosa loads from a path, so spill the bytes to disk first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                tmp_file.write(audio_bytes)
                tmp_file_path = tmp_file.name

            audio_input, _ = librosa.load(tmp_file_path, sr=self.SAMPLE_RATE)

            # An empty clip would make np.max raise; report it as silence.
            if audio_input.size == 0 or np.max(np.abs(audio_input)) < self.SILENCE_THRESHOLD:
                return "No speech detected. Please speak louder."

            input_values = self.stt_processor(
                audio_input,
                return_tensors="pt",
                sampling_rate=self.SAMPLE_RATE,
            ).input_values

            # Inference only: no gradients needed.
            with torch.no_grad():
                logits = self.stt_model(input_values).logits

            # Greedy decoding: most likely token id at every frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.stt_processor.batch_decode(predicted_ids)[0].strip()

            return transcription if transcription else "Could not transcribe audio"

        except Exception as e:
            logger.error(f"Error in speech-to-text: {e}")
            return f"Error processing audio: {str(e)}"
        finally:
            # Remove the temp file on every path, including decode failures.
            if tmp_file_path is not None:
                try:
                    os.unlink(tmp_file_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temp file {tmp_file_path}: {cleanup_error}")

    def text_to_speech(self, text, lang='en'):
        """Synthesize ``text`` to a temporary MP3 file using gTTS.

        Args:
            text: Text to speak. Empty or whitespace-only text yields ``None``.
            lang: Language code passed to ``gTTS``.

        Returns:
            Path of the generated ``.mp3`` file (the caller is responsible
            for deleting it), or ``None`` when synthesis failed or there was
            nothing to speak.
        """
        if not text or not text.strip():
            return None

        tmp_path = None
        try:
            tts = gTTS(text=text, lang=lang, slow=False)

            # Reserve a unique path, then close our handle before gTTS writes.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                tmp_path = tmp_file.name
            tts.save(tmp_path)
            return tmp_path

        except Exception as e:
            logger.error(f"Error in text-to-speech: {e}")
            # Do not leave a half-written file behind when synthesis fails.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
            return None
95
+
96
+ # ============================
97
+ # DATA PREPARATION
98
+ # ============================
99
+
100
def prepare_dataset():
    """Build the retrieval corpus from the tweet_eval "emotion" train split.

    Up to 1000 training rows are sampled at random, their text is normalized,
    and rows whose normalized text is 10 characters or shorter are dropped.

    Returns:
        A list of dicts with keys ``"text"`` (normalized text), ``"emotion"``
        (label name) and ``"original_text"`` (the untouched tweet).
    """
    print("πŸ“Š Loading emotion dataset...")

    # Indexed by the integer ``label`` field of each dataset row.
    label_names = ["anger", "joy", "optimism", "sadness"]

    # Applied in order: URLs, punctuation, digits, then whitespace runs.
    substitutions = (
        (r"http\S+", ""),
        (r"[^\w\s]", ""),
        (r"\d+", ""),
        (r"\s+", " "),
    )

    def normalize(raw):
        """Lower-case ``raw`` and apply every substitution in sequence."""
        result = raw.lower()
        for pattern, replacement in substitutions:
            result = re.sub(pattern, replacement, result)
        return result.strip()

    dataset = load_dataset("cardiffnlp/tweet_eval", "emotion")
    train_rows = dataset['train']
    sampled_rows = random.sample(list(train_rows), min(1000, len(train_rows)))

    normalized_rows = ((normalize(row['text']), row) for row in sampled_rows)
    rag_json = [
        {
            "text": normalized,
            "emotion": label_names[row['label']],
            "original_text": row['text'],
        }
        for normalized, row in normalized_rows
        if len(normalized) > 10
    ]

    print(f"Dataset prepared with {len(rag_json)} samples")
    return rag_json
136
+
137
+ # ============================
138
+ # EMOTION DETECTION MODEL
139
+ # ============================
140
+
141
class EmotionDetector:
    """Text emotion classifier built on a Hugging Face pipeline.

    The raw classifier label is folded into the categories the rest of the
    app uses: ``anger``, ``sadness``, ``joy``, ``optimism`` and ``neutral``.
    """

    # Lower-cased classifier label -> app emotion category.
    _EMOTION_MAPPING = {
        'anger': 'anger',
        'disgust': 'sadness',
        'neutral': 'neutral',
        'joy': 'joy',
        'love': 'joy',
        'happiness': 'joy',
        'sadness': 'sadness',
        'fear': 'sadness',
        'surprise': 'optimism',
        'optimism': 'optimism',
    }
    # Category used for labels missing from the mapping and on errors.
    _FALLBACK_EMOTION = 'optimism'

    def __init__(self):
        """Load tokenizer, model and classification pipeline.

        Raises:
            Exception: Whatever the loaders raise; the error is shown with
                ``st.error`` and then re-raised.
        """
        self.model_name = "j-hartmann/emotion-english-distilroberta-base"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            self.classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                return_all_scores=False
            )
        except Exception as e:
            st.error(f"❌ Error loading emotion model: {e}")
            raise

    def detect_emotion(self, text):
        """Classify ``text`` and return the mapped emotion with its score.

        Args:
            text: The message to classify.

        Returns:
            ``(emotion, confidence)``. Labels missing from the mapping become
            ``'optimism'``. If classification raises, the error is logged and
            ``('optimism', 0.5)`` is returned.
        """
        try:
            # truncation=True keeps over-long messages within the model's
            # input limit instead of failing into the except branch below.
            result = self.classifier(text, truncation=True)
            top_prediction = result[0]
            emotion = top_prediction['label'].lower()
            confidence = top_prediction['score']

            mapped_emotion = self._EMOTION_MAPPING.get(emotion, self._FALLBACK_EMOTION)
            return mapped_emotion, confidence

        except Exception as e:
            logger.error(f"Error in emotion detection: {e}")
            return self._FALLBACK_EMOTION, 0.5
185
+
186
+ # ============================
187
+ # RAG SYSTEM WITH FAISS
188
+ # ============================
189
+
190
class RAGSystem:
    """Similarity search over emotion-labelled texts using FAISS.

    Attributes:
        rag_data: The corpus entries passed in; each entry is read for its
            ``'text'`` and ``'emotion'`` keys.
        texts: The ``'text'`` value of every entry, in corpus order.
        embed_model: The ``SentenceTransformer`` used for corpus and queries.
        embeddings: Embeddings of ``texts`` as a NumPy array.
        index: ``faiss.IndexFlatL2`` over all embeddings.
    """

    def __init__(self, rag_data):
        """Embed the corpus and build the full index.

        Args:
            rag_data: Non-empty list of dicts with ``'text'`` and
                ``'emotion'`` keys.

        Raises:
            ValueError: If ``rag_data`` is empty.
        """
        if not rag_data:
            raise ValueError("rag_data must contain at least one entry")

        self.rag_data = rag_data
        self.texts = [entry['text'] for entry in rag_data]

        self.embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        self.embeddings = self.embed_model.encode(
            self.texts,
            convert_to_numpy=True,
            show_progress_bar=False
        )

        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

        # emotion -> (corpus positions, index over just those positions).
        # Built lazily so each emotion's sub-index is created only once
        # rather than on every query.
        self._emotion_indices = {}

    def _index_for_emotion(self, emotion):
        """Return ``(corpus_positions, faiss_index)`` for ``emotion``.

        When no entry carries ``emotion`` the whole corpus and the full
        index are used instead.
        """
        cached = self._emotion_indices.get(emotion)
        if cached is None:
            positions = [
                i for i, entry in enumerate(self.rag_data)
                if entry['emotion'] == emotion
            ]
            if positions:
                sub_index = faiss.IndexFlatL2(self.embeddings.shape[1])
                sub_index.add(self.embeddings[positions])
            else:
                positions = list(range(len(self.rag_data)))
                sub_index = self.index
            cached = (positions, sub_index)
            self._emotion_indices[emotion] = cached
        return cached

    def retrieve_templates(self, user_input, detected_emotion, top_k=3):
        """Return the corpus texts closest to ``user_input`` for an emotion.

        Args:
            user_input: Query text to embed.
            detected_emotion: Emotion used to narrow the candidates; if no
                entry matches, every entry is a candidate.
            top_k: Maximum number of texts to return.

        Returns:
            Up to ``top_k`` texts ordered as returned by the FAISS search;
            an empty list when ``top_k`` is not positive.
        """
        if top_k <= 0:
            return []

        positions, emotion_index = self._index_for_emotion(detected_emotion)

        user_embedding = self.embed_model.encode([user_input], convert_to_numpy=True)
        _, hits = emotion_index.search(user_embedding, min(top_k, len(positions)))

        # Skip negative positions so they can never index from the end.
        return [self.texts[positions[hit]] for hit in hits[0] if hit >= 0]
241
+
242
+ # ============================
243
+ # RESPONSE GENERATOR
244
+ # ============================
245
+
246
class ResponseGenerator:
    """Compose empathetic replies from canned openers plus retrieved text."""

    def __init__(self, emotion_detector, rag_system):
        """Keep the collaborators and define the opener pool per emotion.

        Args:
            emotion_detector: Object whose ``detect_emotion(text)`` returns
                ``(emotion, confidence)``.
            rag_system: Object whose ``retrieve_templates(text, emotion,
                top_k=...)`` returns a list of strings.
        """
        self.emotion_detector = emotion_detector
        self.rag_system = rag_system

        # One list of opening sentences for every supported emotion.
        self.response_templates = {
            'anger': [
                "I can understand why you're feeling frustrated. It's completely valid to feel this way.",
                "Your anger is understandable. Sometimes situations can be really challenging.",
                "I hear that you're upset, and that's okay. These feelings are important."
            ],
            'sadness': [
                "I'm sorry you're going through a difficult time. Your feelings are valid.",
                "It sounds like you're dealing with something really tough right now.",
                "I can sense your sadness, and I want you to know that it's okay to feel this way."
            ],
            'joy': [
                "I'm so happy to hear about your positive experience! That's wonderful.",
                "Your joy is contagious! It's great to hear such positive news.",
                "I love hearing about things that make you happy. That sounds amazing!"
            ],
            'optimism': [
                "Your positive outlook is inspiring. That's a great way to look at things.",
                "I appreciate your hopeful perspective. That's really encouraging.",
                "It's wonderful to hear your optimistic thoughts. Keep that positive energy!"
            ],
            'neutral': [
                "Thanks for sharing that. I hear you.",
                "I understand. Let's continue exploring this topic together.",
                "I appreciate you telling me that. Let's keep going."
            ]
        }

    def generate_response(self, user_input, top_k=3):
        """Produce a reply for ``user_input``.

        Args:
            user_input: The user's message.
            top_k: Number of retrieved texts to request from the RAG system.

        Returns:
            ``(reply_text, emotion, confidence)``. The reply always ends with
            a disclaimer. If any step raises, an apology containing the error
            text is returned together with ``'neutral'`` and ``0.0``.
        """
        try:
            emotion, score = self.emotion_detector.detect_emotion(user_input)

            retrieved = self.rag_system.retrieve_templates(
                user_input, emotion, top_k=top_k
            )

            # Emotions without their own pool borrow the optimism openers.
            fallback_pool = self.response_templates['optimism']
            opener_pool = self.response_templates.get(emotion, fallback_pool)
            opener = random.choice(opener_pool)

            reply = opener
            if retrieved:
                # Weave in at most 80 characters of one retrieved text.
                snippet = random.choice(retrieved)
                reply = f"{opener} I can relate to what you're sharing - {snippet[:80]}. Remember that your feelings are important and valid."

            footer = "\n\n⚠️ This is an automated response. For serious emotional concerns, please consult a mental health professional."
            return reply + footer, emotion, score

        except Exception as e:
            apology = f"I apologize, but I encountered an error: {str(e)}"
            footer = "\n\n⚠️ This is an automated response. Please consult a professional if needed."
            return apology + footer, 'neutral', 0.0
319
+
320
+ # ============================
321
+ # STREAMLIT APP
322
+ # ============================
323
+
324
# Status strings returned by AudioProcessor.speech_to_text_from_bytes that
# describe a failure rather than something the user said.
_STT_FAILURE_MESSAGES = (
    "No speech detected. Please speak louder.",
    "Could not transcribe audio",
    "STT model not available",
)
_STT_ERROR_PREFIX = "Error processing audio:"

# Number of most recent exchanges drawn in the chat area.
_MAX_VISIBLE_MESSAGES = 15


def _tally_emotions(emotions, preset=()):
    """Count ``emotions``; ``preset`` names are listed first, starting at 0."""
    counts = {name: 0 for name in preset}
    for emotion in emotions:
        counts[emotion] = counts.get(emotion, 0) + 1
    return counts


def _is_stt_failure(transcription):
    """Return True when ``transcription`` is empty or an STT status message."""
    return (
        not transcription
        or transcription in _STT_FAILURE_MESSAGES
        or transcription.startswith(_STT_ERROR_PREFIX)
    )


def _inject_styles():
    """Inject the app's custom CSS through an unsafe-HTML markdown block."""
    st.markdown("""
    <style>
        /* Import Google Fonts */
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

        /* Global styles */
        .stApp {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            font-family: 'Inter', sans-serif;
        }

        /* Main header - more elegant */
        .main-header {
            background: rgba(255, 255, 255, 0.15);
            padding: 2rem;
            border-radius: 20px;
            text-align: center;
            margin-bottom: 2rem;
            backdrop-filter: blur(20px);
            border: 1px solid rgba(255, 255, 255, 0.2);
            color: white;
            box-shadow: 0 8px 32px rgba(0,0,0,0.1);
            transition: all 0.3s ease;
        }

        .main-header:hover {
            transform: translateY(-5px);
            box-shadow: 0 12px 40px rgba(0,0,0,0.2);
        }

        .main-header h1 {
            font-size: 2.5rem;
            font-weight: 700;
            margin-bottom: 0.5rem;
            background: linear-gradient(45deg, #fff, #f0f0f0);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
        }

        .main-header p {
            font-size: 1.2rem;
            opacity: 0.9;
            font-weight: 400;
            margin: 0;
        }


        /* Improved chat messages */
        .chat-message {
            margin-bottom: 1.5rem;
            animation: fadeInUp 0.5s ease;
        }

        @keyframes fadeInUp {
            from { opacity: 0; transform: translateY(20px); }
            to { opacity: 1; transform: translateY(0); }
        }

        .user-message {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 1rem 1.5rem;
            border-radius: 20px 20px 5px 20px;
            margin-left: auto;
            margin-right: 0;
            max-width: 75%;
            box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
            font-weight: 500;
            line-height: 1.5;
        }

        .bot-message {
            background: linear-gradient(to top, #a18cd1 0%, #fbc2eb 100%);
            color: white;
            padding: 1rem 1.5rem;
            border-radius: 20px 20px 20px 5px;
            margin-left: 0;
            margin-right: auto;
            max-width: 75%;
            box-shadow: 0 4px 15px rgba(240, 147, 251, 0.3);
            font-weight: 500;
            line-height: 1.5;
        }

        /* Message headers */
        .message-header {
            font-size: 0.85rem;
            opacity: 0.9;
            margin-bottom: 0.5rem;
            font-weight: 600;
        }

        /* Emotion badges - hidden but styled */
        .emotion-badge {
            display: inline-block;
            padding: 0.2rem 0.6rem;
            border-radius: 12px;
            font-size: 0.75rem;
            font-weight: 600;
            margin-left: 0.5rem;
            opacity: 0.8;
        }



        /* Enhanced buttons */
        .stButton > button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
            color: white !important;
            border: none !important;
            border-radius: 50px !important;
            padding: 1rem 2rem !important;
            font-weight: 600 !important;
            font-size: 1rem !important;
            transition: all 0.3s ease !important;
            box-shadow: 0 6px 20px rgba(102, 126, 234, 0.3) !important;
            min-height: 50px !important;
        }

        .stButton > button:hover {
            transform: translateY(-3px) !important;
            box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4) !important;
            background: linear-gradient(135deg, #7c8ff0 0%, #8a5ab8 100%) !important;
        }

        /* Play button styling */
        .play-button {
            background: linear-gradient(135deg, #28a745 0%, #20c997 100%) !important;
            border-radius: 25px !important;
            padding: 0.5rem 1rem !important;
            font-size: 0.9rem !important;
            margin-top: 0.5rem !important;
            box-shadow: 0 4px 15px rgba(40, 167, 69, 0.3) !important;
        }

        /* Sidebar enhancements */
        .css-1d391kg {
            background: rgba(255, 255, 255, 0.1) !important;
            backdrop-filter: blur(20px) !important;
        }


        /* Stats and metrics */
        .metric-card {
            background: rgba(255, 255, 255, 0.9);
            padding: 1.5rem;
            border-radius: 15px;
            text-align: center;
            box-shadow: 0 4px 15px rgba(0,0,0,0.05);
            margin-bottom: 1rem;
            transition: transform 0.3s ease;
        }

        .metric-card:hover {
            transform: translateY(-3px);
        }

        /* Progress bars */
        .stProgress > div > div > div {
            background: linear-gradient(90deg, #667eea, #764ba2) !important;
            border-radius: 10px !important;
        }

        /* Hide default Streamlit elements */
        .stDeployButton {display: none;}
        footer {visibility: hidden;}
        .stApp > header {visibility: hidden;}

        /* Custom scrollbar */
        .chat-container::-webkit-scrollbar {
            width: 6px;
        }


        /* πŸ”Š Audio recorder container fix */
        .audio-recorder-container {
            background: transparent !important;
            border: none !important;
            box-shadow: none !important;
            padding: 0 !important;
            margin: 0 !important;
        }

        /* 🎀 Recorder button style */
        .audio-recorder-container button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
            color: #fff !important;
            border: none !important;
            border-radius: 50% !important; /* Makes it a perfect circle */
            width: 60px !important;
            height: 60px !important;
            font-size: 1.2rem !important;
            font-weight: bold !important;
            cursor: pointer !important;
            box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
            transition: all 0.3s ease !important;
        }

        /* Hover effect */
        .audio-recorder-container button:hover {
            transform: scale(1.08);
            box-shadow: 0 6px 18px rgba(0,0,0,0.35) !important;
        }


    </style>
    """, unsafe_allow_html=True)


def _render_header():
    """Draw the title banner at the top of the page."""
    st.markdown("""
    <div class="main-header">
        <h1>πŸ€– Empathetic AI Companion</h1>
        <p>Your intelligent partner for emotional support and meaningful conversations</p>
    </div>
    """, unsafe_allow_html=True)


def _init_session_state():
    """Create the session-state entries the page reads, if they are missing."""
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    if "initialized" not in st.session_state:
        initialize_chatbot()

    if "audio_processor" not in st.session_state:
        st.session_state.audio_processor = AudioProcessor()

    if "last_transcription" not in st.session_state:
        st.session_state.last_transcription = ""


def _render_sidebar():
    """Draw the sidebar: voice settings, analytics, quick actions, samples."""
    with st.sidebar:
        st.markdown("### πŸŽ›οΈ Control Panel")

        # Voice settings are mirrored into session state so play_tts and
        # process_message can read them.
        with st.expander("πŸŽ™οΈ Voice Settings", expanded=True):
            st.session_state.tts_language = st.selectbox(
                "Text-to-Speech Language",
                options=['en', 'es', 'fr', 'de', 'it'],
                index=0,
                help="Choose your preferred TTS accent"
            )

            st.session_state.auto_tts = st.toggle(
                "Auto-play Bot Responses",
                value=False,
                help="Automatically play TTS for all bot responses"
            )

        st.divider()

        # Emotion distribution over the whole session.
        if st.session_state.chat_history:
            with st.expander("πŸ“Š Session Analytics", expanded=False):
                emotions = [
                    chat['emotion']
                    for chat in st.session_state.chat_history
                    if 'emotion' in chat
                ]
                if emotions:
                    for emotion, count in _tally_emotions(emotions).items():
                        percentage = (count / len(emotions)) * 100
                        st.metric(
                            f"{emotion.title()}",
                            f"{count} messages",
                            f"{percentage:.1f}%"
                        )

        with st.expander("⚑ Quick Actions", expanded=True):
            col1, col2 = st.columns(2)

            with col1:
                if st.button("πŸ§ͺ Test AI", use_container_width=True):
                    test_emotion_detection()

            with col2:
                if st.button("πŸ—‘οΈ Clear Chat", use_container_width=True):
                    st.session_state.chat_history = []
                    st.session_state.last_transcription = ""
                    st.rerun()

        st.divider()

        # Each sample button sends its message as if the user had typed it.
        with st.expander("πŸ’‘ Try These Messages", expanded=False):
            sample_messages = [
                ("😊", "I'm feeling really happy today!"),
                ("😀", "I'm so frustrated with everything"),
                ("😒", "I feel really sad and alone"),
                ("🌟", "I'm excited about my future!")
            ]

            for i, (emoji, msg) in enumerate(sample_messages):
                if st.button(f"{emoji} {msg[:20]}...", key=f"sample_{i}", use_container_width=True):
                    process_message(msg)
                    st.rerun()

        st.divider()

        st.markdown("""
        <div style="background: rgba(255,255,255,0.1); padding: 1rem; border-radius: 10px; backdrop-filter: blur(10px);">
            <h4 style="color: white; margin-bottom: 0.5rem;">✨ Features</h4>
            <ul style="color: rgba(255,255,255,0.9); font-size: 0.9rem; margin: 0;">
                <li>🎀 Voice Recording & STT</li>
                <li>πŸ”Š Natural TTS Responses</li>
                <li>😊 Advanced Emotion AI</li>
                <li>πŸ’¬ Context-Aware Replies</li>
                <li>πŸ“Š Real-time Analytics</li>
            </ul>
        </div>
        """, unsafe_allow_html=True)


def _render_chat_history():
    """Draw the most recent exchanges, each with a play-audio button."""
    import html  # local import: only needed to escape chat text below

    st.markdown('<div class="chat-container">', unsafe_allow_html=True)

    history = st.session_state.chat_history
    if not history:
        return

    visible = history[-_MAX_VISIBLE_MESSAGES:]
    # Number entries by their position in the full history so widget keys
    # stay tied to one message and the "latest message" check below still
    # works once older messages scroll out of view.
    first_position = len(history) - len(visible)

    for position, chat in enumerate(visible, start=first_position):
        # The markup is rendered with unsafe_allow_html, so user text must
        # be escaped or it would be interpreted as HTML.
        user_text = html.escape(chat['user'], quote=False)
        st.markdown(f"""
        <div class="chat-message">
            <div class="user-message">
                <div class="message-header">πŸ§‘ You β€’ {chat['timestamp']}</div>
                {user_text}
            </div>
        </div>
        """, unsafe_allow_html=True)

        emotion_class = chat.get('emotion', 'optimism')
        confidence = chat.get('confidence', 0.0)
        # Bot text embeds retrieved snippets and error messages: escape too.
        bot_text = html.escape(chat['bot'].replace('⚠️', '⚠️ '), quote=False)

        st.markdown(f"""
        <div class="chat-message">
            <div class="bot-message">
                <div class="message-header">
                    πŸ€– AI Assistant
                    <span class="emotion-badge {html.escape(emotion_class)}">
                        {emotion_class.title()} {confidence:.0%}
                    </span>
                </div>
                {bot_text}
            </div>
        </div>
        """, unsafe_allow_html=True)

        col_tts, _col_spacer = st.columns([2, 6])
        with col_tts:
            if st.button("πŸ”Š Play Audio", key=f"tts_{position}", help="Listen to response"):
                play_tts(chat['bot'])

        # Auto-play only the newest reply, and only once.
        if (st.session_state.auto_tts and
                position == len(history) - 1 and
                chat.get('should_play_tts', False)):
            play_tts(chat['bot'])
            chat['should_play_tts'] = False


def _handle_voice_input(audio_bytes):
    """Transcribe a recording and, if it is new speech, send it to the bot."""
    with st.spinner("πŸ”„ Processing your voice..."):
        transcription = st.session_state.audio_processor.speech_to_text_from_bytes(audio_bytes)

    # Status and error strings are shown as a warning, never sent as a message.
    if _is_stt_failure(transcription):
        st.warning(f"⚠️ {transcription}")
        return

    st.success(f"πŸŽ™οΈ **Transcribed:** \"{transcription}\"")

    # Reruns present the same recording again; only new text is processed.
    if transcription != st.session_state.last_transcription:
        st.session_state.last_transcription = transcription
        process_message(transcription, from_voice=True)
        st.rerun()


def _render_input_section():
    """Draw the text box, the voice recorder and the send button."""
    st.markdown('<div class="input-section">', unsafe_allow_html=True)

    text_area = st.container()
    col_voice, col_send = st.columns(2)

    with text_area:
        # Non-empty label (hidden via label_visibility) instead of "".
        user_input = st.text_input(
            "Your message",
            placeholder="Share what's on your mind... How can I help you today?",
            label_visibility="collapsed",
            key="main_input"
        )

    audio_bytes = None
    with col_voice:
        audio_file = st.audio_input("Record a voice message")
        if audio_file is not None:
            audio_bytes = audio_file.read()
            # Play the recording back to the user.
            st.audio(audio_bytes, format="audio/wav")

    with col_send:
        if st.button("πŸ“€ Send Message", type="primary", key="send_btn", use_container_width=True):
            if user_input.strip():
                process_message(user_input.strip())
                st.rerun()

    if audio_bytes is not None:
        _handle_voice_input(audio_bytes)

    st.markdown('</div>', unsafe_allow_html=True)


def _render_stats_panel():
    """Draw emotion trends and session metrics, or tips when chat is empty."""
    history = st.session_state.chat_history

    if not history:
        # NOTE(review): the privacy text below says conversations are
        # processed locally; play-back goes through gTTS - confirm whether
        # that sends text to an external service and adjust the wording.
        st.markdown("""
        ### πŸš€ Getting Started

        **Tips for better conversations:**
        - Be specific about your feelings
        - Share context about your situation
        - Use voice input for natural interaction
        - Try the sample messages below

        **Privacy Note:**
        Your conversations are processed locally and not stored permanently.
        """)
        return

    st.markdown("### πŸ“ˆ Live Insights")

    # Emotion trend over the ten most recent exchanges.
    recent_emotions = [
        chat.get('emotion', 'optimism')
        for chat in history[-10:]
        if 'emotion' in chat
    ]

    if recent_emotions:
        st.markdown("**Recent Emotions:**")
        emotion_scores = _tally_emotions(
            recent_emotions,
            preset=('anger', 'sadness', 'joy', 'optimism'),
        )

        total = len(recent_emotions)
        for emotion, count in emotion_scores.items():
            if count > 0:
                st.progress(count / total, text=f"{emotion.title()}: {count}/{total}")

    # Whole-session metrics appear once there are more than two exchanges.
    if len(history) > 2:
        st.divider()
        st.markdown("**Session Overview:**")

        emotions = [chat.get('emotion', 'optimism') for chat in history]

        st.metric("Messages", len(history))

        if emotions:
            # Ties go to the emotion seen first, so the result is stable.
            counts = _tally_emotions(emotions)
            most_common = max(counts, key=counts.get)
            st.metric("Dominant Emotion", most_common.title())

            # Share of messages classified as joy or optimism.
            positive_count = counts.get('joy', 0) + counts.get('optimism', 0)
            mood_score = positive_count / len(emotions)

            if mood_score > 0.6:
                st.success("😊 Positive Mood")
            elif mood_score > 0.4:
                st.info("😐 Balanced Mood")
            else:
                st.warning("πŸ˜” Needs Support")


def main():
    """Render the Streamlit page.

    Configures the page, injects the CSS theme, initializes session state,
    then draws the sidebar, the chat/input column and the statistics column.
    """
    st.set_page_config(
        page_title="Empathetic AI Companion",
        page_icon="πŸ€–",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    _inject_styles()
    _render_header()
    _init_session_state()
    _render_sidebar()

    col_main, col_stats = st.columns([7, 3])

    with col_main:
        _render_chat_history()
        _render_input_section()

    with col_stats:
        _render_stats_panel()
813
+
814
def initialize_chatbot():
    """Build every chatbot component and store it in Streamlit session state.

    Stores ``rag_data``, ``emotion_detector``, ``rag_system`` and
    ``response_generator`` in that order, then sets ``initialized``. A
    progress bar and status line are updated before each step. If any step
    raises, the error is shown and the script run is stopped with
    ``st.stop()``.
    """
    state = st.session_state

    with st.spinner("πŸš€ Loading AI models..."):
        try:
            bar = st.progress(0)
            status = st.empty()

            # (status message, progress value, session-state key, factory).
            # Later factories read what earlier ones stored, so order matters.
            stages = (
                ("πŸ“Š Loading emotion dataset...", 25, "rag_data",
                 prepare_dataset),
                ("🧠 Loading emotion detection model...", 50, "emotion_detector",
                 EmotionDetector),
                ("πŸ” Setting up knowledge retrieval...", 75, "rag_system",
                 lambda: RAGSystem(state.rag_data)),
                ("πŸ’¬ Preparing response generation...", 100, "response_generator",
                 lambda: ResponseGenerator(state.emotion_detector, state.rag_system)),
            )

            for message, percent, slot, build in stages:
                status.text(message)
                bar.progress(percent)
                state[slot] = build()

            state.initialized = True

            # Remove the transient loading widgets.
            bar.empty()
            status.empty()

            st.success("βœ… AI Companion ready! Start your conversation below.")

        except Exception as e:
            st.error(f"❌ Failed to initialize: {str(e)}")
            st.info("πŸ’‘ Try refreshing the page or check your internet connection.")
            st.stop()
856
+
857
def process_message(user_input, from_voice=False):
    """Generate a reply for ``user_input`` and append it to the chat history.

    Args:
        user_input: The user's message; whitespace-only input is ignored.
        from_voice: True when the message came from the voice recorder.

    Returns:
        None. On failure an error and a hint are shown and the error is
        logged; nothing is appended.
    """
    if user_input.strip() == "":
        return

    try:
        # The spinner acts as the "typing" indicator while the reply is built.
        with st.spinner("πŸ€– AI is thinking..."):
            reply, emotion, score = st.session_state.response_generator.generate_response(
                user_input,
                top_k=3
            )

        st.session_state.chat_history.append({
            'user': user_input,
            'bot': reply,
            'emotion': emotion,
            'confidence': score,
            'timestamp': datetime.now().strftime("%H:%M"),
            'from_voice': from_voice,
            # Lets the page auto-play this reply once when auto-play is on.
            'should_play_tts': st.session_state.get('auto_tts', False)
        })

        logger.info(f"User ({'Voice' if from_voice else 'Text'}): {user_input[:50]}... | Emotion: {emotion} ({score:.2f})")

    except Exception as e:
        st.error(f"❌ Something went wrong: {str(e)}")
        st.info("πŸ’‘ Please try again or rephrase your message.")
        logger.error(f"Processing error: {e}")
891
+
892
def play_tts(text):
    """Speak ``text`` through the session's audio processor.

    The text is stripped of characters other than word characters,
    whitespace and ``. , ! ? '``, cut to 500 characters, synthesized in the
    language stored in ``st.session_state['tts_language']`` (default
    ``'en'``) and played with autoplay.

    Args:
        text: The message to speak. Nothing happens if it cleans to empty.

    Returns:
        None. Failures are logged and reported with a toast.
    """
    audio_file = None
    try:
        clean_text = re.sub(r'[^\w\s\.\,\!\?\']', '', text)
        clean_text = clean_text.replace('⚠️', '').strip()

        if not clean_text:
            return

        tts_lang = st.session_state.get('tts_language', 'en')

        with st.spinner("πŸ”Š Generating audio..."):
            audio_file = st.session_state.audio_processor.text_to_speech(
                clean_text[:500],  # Limit length
                lang=tts_lang
            )

        if audio_file:
            with open(audio_file, 'rb') as f:
                audio_bytes = f.read()

            st.audio(audio_bytes, format='audio/mp3', autoplay=True)
        else:
            # text_to_speech signals failure with None; tell the user.
            st.toast("⚠️ Could not generate audio", icon="πŸ”Š")

    except Exception as e:
        logger.error(f"TTS error: {e}")
        st.toast("⚠️ Could not generate audio", icon="πŸ”Š")
    finally:
        # Delete the temp MP3 even when reading or playback failed.
        if audio_file:
            try:
                os.unlink(audio_file)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temp audio file: {cleanup_error}")
921
+
922
def test_emotion_detection():
    """Run the session's emotion detector on four fixed sentences.

    For every sentence the text, the detected emotion with its confidence,
    and a matching emoji are drawn; a divider separates consecutive samples.
    """
    samples = (
        "I'm absolutely thrilled about my new promotion!",
        "I feel completely overwhelmed and sad today",
        "This traffic is making me so angry and frustrated!",
        "I have hope that everything will work out perfectly",
    )
    # Emoji shown beside each detected emotion.
    emoji_map = {'anger': '😠', 'sadness': '😒', 'joy': '😊', 'optimism': '🌟'}
    detector = st.session_state.emotion_detector

    st.markdown("### πŸ§ͺ Emotion Detection Demo")

    last_position = len(samples) - 1
    for position, sample in enumerate(samples):
        with st.container():
            emotion, confidence = detector.detect_emotion(sample)

            text_col, emoji_col = st.columns([3, 1])
            with text_col:
                st.write(f"**Text:** {sample}")
                st.write(f"**Detected:** {emotion.title()} ({confidence:.1%} confidence)")
            with emoji_col:
                st.markdown(f"### {emoji_map.get(emotion, 'πŸ€”')}")

            # No divider after the final sample.
            if position < last_position:
                st.divider()
948
+
949
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main()