Spaces:
Runtime error
Runtime error
| ############################################################################################################################# | |
| # Filename : app.py | |
| # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs. | |
| # Author : Georgios Ioannou | |
| # | |
| # TODO: Add code for Google Gemma 7b and 7b-it. | |
| # TODO: Write code documentation. | |
| # Copyright © 2024 by Georgios Ioannou | |
| ############################################################################################################################# | |
| # Import libraries. | |
| import os # Load environment variable(s). | |
| import requests # Send HTTP GET request to Hugging Face models for inference. | |
| import streamlit as st # Build the GUI of the application. | |
| import streamlit.components.v1 as components | |
| from dataclasses import dataclass | |
| from dotenv import find_dotenv, load_dotenv # Read local .env file. | |
| from langchain.callbacks import get_openai_callback | |
| from langchain.chains import ConversationChain | |
| from langchain.llms import OpenAI | |
| from policies import complex_policy, simple_policy | |
| from transformers import pipeline # Access to Hugging Face models. | |
| from typing import Literal | |
| ############################################################################################################################# | |
# Load environment variable(s).
# Read a local .env file if one is found. By default load_dotenv does not
# override variables that are already set in the process environment, so
# deployment-provided secrets still take precedence.
load_dotenv(find_dotenv())
# HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
| ############################################################################################################################# | |
@dataclass
class Message:
    """A single entry in the chat history.

    The ``@dataclass`` decorator generates the ``__init__`` that callers rely on
    (``Message("human", text)``); without it, construction raises ``TypeError``.

    Attributes:
        origin: Who produced the message, ``"human"`` or ``"ai"``.
        message: The text to display.
    """

    origin: Literal["human", "ai"]  # Selects the icon and bubble style when rendered.
    message: str  # Raw message text.
| ############################################################################################################################# | |
# Remember (keep history of) the user's messages across each rendering of the application.
def initialize_session_state():
    """Seed ``st.session_state`` with defaults the first time the app renders.

    Keys that already exist are left untouched, so the chat history, token
    count and sidebar selections survive every re-rendering of the page.
    """
    defaults = {
        "history": [],  # A fresh list per call; only stored if the key is absent.
        "token_count": 0,
        "policy": "No Policy",
        "is_gemma_model": False,
        # Must be one of the icon files offered by the "AI Icon" selector
        # (ai1.png / ai2.png); the previous default "a1.png" matched neither.
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
| ############################################################################################################################# | |
def _policy_prompt(text):
    """Build the safety-check prompt for ``text`` under the selected policy.

    Args:
        text: The message to assess (the user's prompt or the LLM's answer).

    Returns:
        For "Simple Policy" / "Complex Policy": a prompt made of the task
        description, the policy categories, a conversation containing ``text``,
        and the required output format. For "No Policy" (and any unrecognised
        policy value): ``text`` unchanged.
    """
    policy_text = {
        "Simple Policy": simple_policy,
        "Complex Policy": complex_policy,
    }.get(st.session_state.policy)
    if policy_text is None:
        return text

    role = "User"  # This can be an Agent.
    task = (
        "Task: Check if there is unsafe content in\n"
        f"'{role}' messages in conversations\n"
        "according our safety policy with the below categories."
    )
    conversation = f"<BEGIN CONVERSATION>\nUser: {text}\n<END CONVERSATION>"
    output_format = (
        f"Provide your safety assessment for '{role}' in the above conversation:\n"
        "- First line must read 'safe' or 'unsafe'.\n"
        "- If unsafe, a second line must include a comma-separated list of violated categories."
    )
    return f"\n{task}\n{policy_text}\n{conversation}\n{output_format}\n"


def _is_unsafe(assessment):
    """Return True when a safety-check response flags its input as unsafe.

    This is a case-insensitive substring test: any response containing
    "unsafe" anywhere counts as a violation.
    """
    return "unsafe" in assessment.lower()


def _respond_with_safety_checks(human_prompt):
    """Run check 1 on the prompt, answer it, run check 2 on the answer.

    Appends the human message and then exactly one AI message to
    ``st.session_state.history``.
    """
    # Safety check 1: what goes in (user input).
    # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
    input_assessment = st.session_state.conversation.run(_policy_prompt(human_prompt))
    st.session_state.history.append(Message("human", human_prompt))
    if _is_unsafe(input_assessment):
        st.session_state.history.append(Message("ai", input_assessment))
        return

    # The input is safe: answer it with a separate chain for the selected model.
    answer_chain = ConversationChain(
        llm=OpenAI(
            temperature=0.2,
            openai_api_key=OPENAI_API_KEY,
            model_name=st.session_state.model,
        ),
    )
    llm_response = answer_chain.run(human_prompt)

    # Safety check 2: what goes out (LLM output).
    output_assessment = st.session_state.conversation.run(_policy_prompt(llm_response))
    if _is_unsafe(output_assessment):
        st.session_state.history.append(
            Message(
                "ai",
                "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
            )
        )
    else:
        st.session_state.history.append(Message("ai", llm_response))


def on_click_callback():
    """Handle a submitted chat message (``st.session_state.human_prompt``).

    Safety check 1 sends the (policy-wrapped) prompt to
    ``st.session_state.conversation``. If that response contains "unsafe",
    the response itself is shown as the AI reply. Otherwise, a fresh chain
    answers the prompt, and safety check 2 assesses that answer before it is
    shown. Tokens used by every OpenAI call are added once to
    ``st.session_state.token_count``.
    """
    human_prompt = st.session_state.human_prompt

    if st.session_state.is_gemma_model:
        # TODO: Gemma inference is not implemented yet. Previously this path
        # reached an unbound response variable and crashed the callback.
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(
            Message("ai", "Gemma models are not supported yet. Please select a GPT model.")
        )
        return

    with get_openai_callback() as cb:
        try:
            _respond_with_safety_checks(human_prompt)
        finally:
            # The callback handler accumulates total_tokens over every call in this
            # `with` block, so add it exactly once. Adding it after each call
            # (as before) counted earlier calls multiple times.
            st.session_state.token_count += cb.total_tokens
| ############################################################################################################################# | |
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet at ``file_name`` into the page as a ``<style>`` element.

    Args:
        file_name: Path to a CSS file, decoded as UTF-8.
    """
    # Explicit encoding: the default open() encoding depends on the platform locale.
    with open(file_name, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
| ############################################################################################################################# | |
def _render_header():
    """Render the page title, both subtitles and the centred CTP logo."""
    st.markdown(
        """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
        Responsible AI</h1>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
        Showcase the importance of Responsible AI in LLMs</h3>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
        CUNY Tech Prep Tutorial 6</h2>""",
        unsafe_allow_html=True,
    )
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image(image="./static/ctp.png")


def _select_model():
    """Sidebar model selector; stores the choice and prepares the safety-check chain."""
    models = [
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gemma-7b",
        "gemma-7b-it",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.write(f"Current Model: {selected_model}")
    st.session_state.model = selected_model
    # Recomputed on every render so switching back from a Gemma model to a GPT
    # model clears the flag (previously it stayed True forever once set).
    st.session_state.is_gemma_model = "gemma" in selected_model
    if "gpt" in selected_model:
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=selected_model,
            ),
        )
    # TODO: Load Gemma models from Hugging Face.


def _select_policy():
    """Sidebar policy selector; stores the chosen policy name."""
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.write(f"Current Policy: {selected_policy}")
    st.session_state.policy = selected_policy


def _select_icons():
    """Sidebar icon selectors; store the icon file names for AI and user."""
    ai_icons = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(ai_icons))
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
    st.session_state.selected_ai_icon = ai_icons[selected_ai_icon]

    user_icons = {"Man": "man.png", "Woman": "woman.png"}
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(user_icons))
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")
    st.session_state.selected_user_icon = user_icons[selected_user_icon]


def _render_chat_history(container):
    """Render every message in ``st.session_state.history`` inside ``container``."""
    import html

    with container:
        for chat in st.session_state.history:
            is_ai = chat.origin == "ai"
            icon = (
                st.session_state.selected_ai_icon
                if is_ai
                else st.session_state.selected_user_icon
            )
            row_class = "" if is_ai else "row-reverse"
            bubble_class = "ai-bubble" if is_ai else "human-bubble"
            # The message comes from the user or the LLM, so escape it before
            # embedding it in HTML rendered with unsafe_allow_html=True.
            div = f"""
<div class="chat-row {row_class}">
<img class="chat-icon" src="app/static/{icon}" width=32 height=32>
<div class="chat-bubble {bubble_class}">
&#8203;{html.escape(chat.message)}
</div>
</div>
"""
            st.markdown(div, unsafe_allow_html=True)
        for _ in range(3):
            st.markdown("")


def _render_prompt_form(form):
    """Render the chat text input and the Submit button inside ``form``."""
    with form:
        st.markdown("**Chat**")
        input_column, button_column = st.columns((6, 1))
        # Large text input in the left column.
        input_column.text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        # Red button in the right column.
        button_column.form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )


def _render_footer():
    """Render the link to the author's GitHub profile."""
    st.markdown(
        """
        <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
        <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
        </p>
        """,
        unsafe_allow_html=True,
    )


def _install_enter_to_submit():
    """Inject a script that clicks the "Submit" button when Enter is pressed."""
    components.html(
        """
        <script>
        const parentWindow = window.parent;
        const streamlitDoc = parentWindow.document;
        // Keep a single listener on the parent document: drop the one added by a
        // previous injection of this script before adding a fresh one.
        if (parentWindow.__submitOnEnter) {
            streamlitDoc.removeEventListener('keydown', parentWindow.__submitOnEnter);
        }
        parentWindow.__submitOnEnter = function(e) {
            if (e.key !== 'Enter') {
                return;
            }
            // Look the button up when the key is pressed and skip if it is absent.
            const submitButton = Array.from(
                streamlitDoc.querySelectorAll('.stButton > button')
            ).find(el => el.innerText === 'Submit');
            if (submitButton) {
                submitButton.click();
            }
        };
        streamlitDoc.addEventListener('keydown', parentWindow.__submitOnEnter);
        </script>
        """,
        height=0,
        width=0,
    )


# Main function to create the Streamlit web application.
def main():
    """Build the Responsible AI demo page.

    Renders the header, the sidebar selectors (model, policy, icons), the chat
    history, the chat input form with its token counter, and the footer. It
    also installs a script so that the Enter key on the keyboard clicks Submit.
    """
    # Page title and favicon. The Streamlit docs require set_page_config to be
    # the first Streamlit call on the page.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
    initialize_session_state()

    local_css("./static/styles/styles.css")
    _render_header()

    _select_model()
    _select_policy()
    _select_icons()

    # Placeholders, created in page order: chat messages, user input, token count.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    _render_chat_history(chat_placeholder)
    _render_prompt_form(prompt_placeholder)
    token_placeholder.caption(
        f"""
        Used {st.session_state.token_count} tokens \n
        """
    )

    _render_footer()
    _install_enter_to_submit()
| ############################################################################################################################# | |
# Entry point: build the page when this file is executed as the main script
# (the header identifies it as the Streamlit application app.py).
if __name__ == "__main__":
    main()