Spaces:
Sleeping
Sleeping
| from flask import Flask, request, send_file, abort | |
| import requests | |
| import io | |
| from PIL import Image | |
| from cachetools import TTLCache, cached | |
| import random | |
| import os | |
| import urllib.parse | |
| import hashlib | |
| from deep_translator import GoogleTranslator | |
| from langdetect import detect | |
| app = Flask(__name__) | |
| # Максимальные значения для ширины и высоты | |
| MAX_WIDTH = 1384 | |
| MAX_HEIGHT = 1384 | |
| # Кэш на 10 минут | |
| cache = TTLCache(maxsize=100, ttl=600) | |
| # Получаем ключи из переменной окружения | |
| keys = os.getenv("keys", "").split(',') | |
| if not keys: | |
| raise ValueError("Environment variable 'keys' must be set with a comma-separated list of API keys.") | |
| def get_random_key(): | |
| return random.choice(keys) | |
| def generate_cache_key(prompt, width, height, seed, model_name): | |
| # Создаем уникальный ключ на основе всех параметров | |
| return hashlib.md5(f"{prompt}_{width}_{height}_{seed}_{model_name}".encode()).hexdigest() | |
| def scale_dimensions(width, height, max_width, max_height): | |
| """Масштабирует размеры изображения, сохраняя соотношение сторон, и округляет до чисел, кратных 8.""" | |
| aspect_ratio = width / height | |
| if width > max_width or height > max_height: | |
| if width / max_width > height / max_height: | |
| width = max_width | |
| height = int(width / aspect_ratio) | |
| else: | |
| height = max_height | |
| width = int(height * aspect_ratio) | |
| # Округляем до ближайших чисел, кратных 8 | |
| width = (width + 3) // 8 * 8 | |
| height = (height + 3) // 8 * 8 | |
| return width, height | |
| def generate_cached_image(prompt, width, height, seed, model_name, api_key): | |
| headers = { | |
| "Authorization": f"Bearer {api_key}", | |
| "Content-Type": "application/json" | |
| } | |
| data = { | |
| "inputs": prompt, | |
| "parameters": { | |
| "width": width, | |
| "height": height, | |
| "seed": seed | |
| } | |
| } | |
| try: | |
| response = requests.post( | |
| f"https://api-inference.huggingface.co/models/{model_name}", | |
| headers=headers, | |
| json=data, | |
| timeout=1550 # Таймаут 3 минуты | |
| ) | |
| response.raise_for_status() | |
| image_data = response.content | |
| image = Image.open(io.BytesIO(image_data)) | |
| return image | |
| except requests.exceptions.HTTPError as http_err: | |
| app.logger.error(f"HTTP error occurred: {http_err} - Response: {response.text}") | |
| return None | |
| except requests.exceptions.Timeout as timeout_err: | |
| app.logger.error(f"Timeout error occurred: {timeout_err}") | |
| return None | |
| except requests.exceptions.RequestException as req_err: | |
| app.logger.error(f"Request error occurred: {req_err}") | |
| return None | |
| def get_image(prompt): | |
| width = request.args.get('width', type=int, default=512) | |
| height = request.args.get('height', type=int, default=512) | |
| seed = request.args.get('seed', type=int, default=25) | |
| model_name = request.args.get('model', default="black-forest-labs/FLUX.1-schnell").replace('+', '/') | |
| api_key = request.args.get('key', default=None) | |
| # Декодируем URL-кодированный prompt | |
| prompt = urllib.parse.unquote(prompt) | |
| # Определяем язык промпта | |
| try: | |
| language = detect(prompt) | |
| except Exception as e: | |
| app.logger.error(f"Error detecting language: {e}") | |
| return send_error_image() | |
| # Переводим промпт, если он не на английском языке | |
| if language != 'en': | |
| try: | |
| translator = GoogleTranslator(source=language, target='en') | |
| prompt = translator.translate(prompt) | |
| except Exception as e: | |
| app.logger.error(f"Error translating prompt: {e}") | |
| return send_error_image() | |
| # Масштабируем размеры изображения, если они превышают максимальные значения, и округляем до чисел, кратных 8 | |
| width, height = scale_dimensions(width, height, MAX_WIDTH, MAX_HEIGHT) | |
| # Используем указанный ключ, если он предоставлен, иначе выбираем случайный ключ | |
| if api_key is None: | |
| api_key = get_random_key() | |
| try: | |
| image = generate_cached_image(prompt, width, height, seed, model_name, api_key) | |
| if image is None: | |
| return send_error_image() | |
| except Exception as e: | |
| app.logger.error(f"Error generating image: {e}") | |
| return send_error_image() | |
| img_byte_arr = io.BytesIO() | |
| image.save(img_byte_arr, format='PNG') | |
| img_byte_arr = img_byte_arr.getvalue() | |
| return send_file( | |
| io.BytesIO(img_byte_arr), | |
| mimetype='image/png' | |
| ) | |
| def health_check(): | |
| return "OK", 200 | |
| def send_error_image(): | |
| error_image_url = "https://raw.githubusercontent.com/Igroshka/-/refs/heads/main/img/nuai/errorimg.png" | |
| try: | |
| response = requests.get(error_image_url) | |
| response.raise_for_status() | |
| error_image = Image.open(io.BytesIO(response.content)) | |
| img_byte_arr = io.BytesIO() | |
| error_image.save(img_byte_arr, format='PNG') | |
| img_byte_arr = img_byte_arr.getvalue() | |
| return send_file( | |
| io.BytesIO(img_byte_arr), | |
| mimetype='image/png' | |
| ) | |
| except Exception as e: | |
| app.logger.error(f"Error fetching error image: {e}") | |
| abort(500, description="Error fetching error image") | |
| if __name__ == '__main__': | |
| app.run(host='0.0.0.0', port=7860, debug=False) | |