Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoProcessor, AutoModel | |
| import scipy.io.wavfile as wavfile | |
| import spaces | |
| # Processor | |
def load_model():
    """Load the Bark-small processor and model from the Hugging Face Hub.

    Returns:
        A ``(processor, model)`` tuple. The model is put into evaluation
        mode before it is returned.
    """
    checkpoint = "suno/bark-small"
    bark_processor = AutoProcessor.from_pretrained(checkpoint)
    bark_model = AutoModel.from_pretrained(checkpoint)
    # Inference-only app: switch off training-time behaviour (e.g. dropout).
    bark_model.eval()
    return bark_processor, bark_model
# Load the processor and model once, at import time; text_to_speech reads
# these module-level globals instead of reloading the weights per request.
print("Loading models...")
processor, model = load_model()
print("Models loaded successfully!")
@spaces.GPU
def text_to_speech(text):
    """Synthesize speech for *text* with Bark and return a WAV file path.

    ``spaces.GPU`` requests a GPU for the duration of the call on a ZeroGPU
    Space and does nothing on other hardware. Inside the call the model runs
    on CUDA when it is available and on CPU otherwise.

    Args:
        text: The text to speak.

    Returns:
        Path to a newly created ``.wav`` file for the ``gr.Audio`` output.

    Raises:
        gr.Error: If *text* is empty or speech generation fails. Gradio shows
            the message to the user. Returning the message as a string would
            instead be treated as an audio file path.
    """
    import tempfile  # local import keeps this fix self-contained

    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # load_model() loads the weights without a device argument (i.e. on
        # CPU); move them here because on ZeroGPU a GPU is only attached
        # while this decorated function is running.
        model.to(device)
        inputs = processor(text=[text], return_tensors="pt").to(device)
        with torch.no_grad():  # inference only: no autograd bookkeeping
            speech_values = model.generate(**inputs, do_sample=True)
        # squeeze() drops size-1 dims; assumes the output is
        # (1, n_samples) for this single prompt -- TODO confirm.
        audio_data = speech_values.cpu().numpy().squeeze()
        sampling_rate = model.generation_config.sample_rate
        # Use a unique file per request. A shared fixed filename lets
        # concurrent requests overwrite each other's audio.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_path = tmp.name
        wavfile.write(temp_path, sampling_rate, audio_data)
        return temp_path
    except Exception as e:
        raise gr.Error(f"Error generating speech: {e}") from e
def _build_demo():
    """Assemble the Gradio interface around ``text_to_speech``.

    The interface has one text box as input, one audio player as output,
    and two Hindi example prompts.
    """
    text_box = gr.Textbox(
        label="Enter text",
        placeholder="दिल्ली मेट्रो में आपका स्वागत है",
    )
    audio_player = gr.Audio(label="Generated Speech")
    example_rows = [
        ["दिल्ली मेट्रो में आपका स्वागत है"],
        ["अगला स्टेशन राजीव चौक है"],
    ]
    return gr.Interface(
        fn=text_to_speech,
        inputs=[text_box],
        outputs=audio_player,
        title="Bark TTS Test App",
        description="This app generates speech from text using the Bark TTS model.",
        examples=example_rows,
        theme="default",
    )


demo = _build_demo()

if __name__ == "__main__":
    # Start the web server only when this file is executed directly.
    demo.launch()