import spaces
import gradio as gr
import torch
import torchaudio
import os
from einops import rearrange
import gc
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.data.utils import read_video, merge_video_audio, load_and_process_audio
import stat
import platform
import logging
from transformers import logging as transformers_logging

try:
    # Used only for spectrogram previews in the progress callback; stable-audio-tools'
    # own demo imports this helper from aeiou (assumed optional dependency here).
    from aeiou.viz import audio_spectrogram_image
except ImportError:
    audio_spectrogram_image = None

# Silence verbose transformers logging.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)
# Load the pretrained AudioX model and its configuration once at startup.
model, model_config = get_pretrained_model('HKUSTAudio/AudioX')
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# World-writable temp directories for Gradio uploads and extracted videos.
TEMP_DIR = "tmp/gradio"
os.makedirs(TEMP_DIR, exist_ok=True)
os.chmod(TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

VIDEO_TEMP_DIR = os.path.join(TEMP_DIR, "videos")
os.makedirs(VIDEO_TEMP_DIR, exist_ok=True)
os.chmod(VIDEO_TEMP_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
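
# Note: this Space runs on ZeroGPU (hence `import spaces`); the generation function below
# would typically be decorated with @spaces.GPU so the call is scheduled on the allocated
# GPU. That decorator is an assumption about the deployment and is not shown in this listing.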
def generate_cond(
    prompt,
    negative_prompt=None,
    video_file=None,
    audio_prompt_file=None,
    audio_prompt_path=None,
    seconds_start=0,
    seconds_total=10,
    cfg_scale=7.0,
    steps=100,
    preview_every=0,
    seed=-1,
    sampler_type="dpmpp-3m-sde",
    sigma_min=0.03,
    sigma_max=500,
    cfg_rescale=0.0,
    use_init=False,
    init_audio=None,
    init_noise_level=0.1,
    mask_cropfrom=None,
    mask_pastefrom=None,
    mask_pasteto=None,
    mask_maskstart=None,
    mask_maskend=None,
    mask_softnessL=None,
    mask_softnessR=None,
    mask_marination=None,
    batch_size=1
):
    # Free any cached GPU memory from a previous run.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    print(f"Prompt: {prompt}")

    preview_images = []
    if preview_every == 0:
        preview_every = None

    # Pick the best available device: MPS on Apple Silicon, then CUDA, then CPU.
    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        has_mps = False
    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    global model
    model = model.to(device)

    target_fps = model_config.get("video_fps", 5)
    model_type = model_config.get("model_type", "diffusion_cond")

    # Resolve the uploaded video path (Gradio may pass a dict or a file-like object).
    if video_file is not None:
        actual_video_path = video_file['name'] if isinstance(video_file, dict) else video_file.name
    else:
        actual_video_path = None

    # An uploaded audio prompt takes precedence over a typed-in path.
    if audio_prompt_file is not None:
        audio_path = audio_prompt_file.name
    elif audio_prompt_path:
        audio_path = audio_prompt_path.strip()
    else:
        audio_path = None
    # Preprocess the conditioning inputs.
    video_tensors = read_video(actual_video_path, seek_time=seconds_start, duration=seconds_total, target_fps=target_fps)
    audio_tensor = load_and_process_audio(audio_path, sample_rate, seconds_start, seconds_total)
    audio_tensor = audio_tensor.to(device)

    seconds_input = sample_size / sample_rate

    if not prompt:
        prompt = ""

    conditioning = [{
        "video_prompt": [video_tensors.unsqueeze(0)],
        "text_prompt": prompt,
        "audio_prompt": audio_tensor.unsqueeze(0),
        "seconds_start": seconds_start,
        "seconds_total": seconds_input
    }]

    if negative_prompt:
        negative_conditioning = [{
            "video_prompt": [video_tensors.unsqueeze(0)],
            "text_prompt": negative_prompt,
            "audio_prompt": audio_tensor.unsqueeze(0),
            "seconds_start": seconds_start,
            "seconds_total": seconds_total
        }]
    else:
        negative_conditioning = None

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size
    def progress_callback(callback_info):
        # Periodically decode the partially denoised latent and render a spectrogram
        # preview (collected in preview_images, but not surfaced in this demo's outputs).
        nonlocal preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]
        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            if audio_spectrogram_image is not None:
                audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
                preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f}"))
| if model_type == "diffusion_cond": | |
| audio = generate_diffusion_cond( | |
| model, | |
| conditioning=conditioning, | |
| negative_conditioning=negative_conditioning, | |
| steps=steps, | |
| cfg_scale=cfg_scale, | |
| batch_size=batch_size, | |
| sample_size=input_sample_size, | |
| sample_rate=sample_rate, | |
| seed=seed, | |
| device=device, | |
| sampler_type=sampler_type, | |
| sigma_min=sigma_min, | |
| sigma_max=sigma_max, | |
| init_audio=init_audio, | |
| init_noise_level=init_noise_level, | |
| mask_args=None, | |
| callback=progress_callback if preview_every is not None else None, | |
| scale_phi=cfg_rescale | |
| ) | |
| audio = rearrange(audio, "b d n -> d (b n)") | |
| samples_10s = 10 * sample_rate | |
| audio = audio[:, :samples_10s] | |
| audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() | |
| output_dir = "demo_result" | |
| os.makedirs(output_dir, exist_ok=True) | |
| output_audio_path = f"{output_dir}/output.wav" | |
| torchaudio.save(output_audio_path, audio, sample_rate) | |
    # If a video was provided, mux the generated audio back into it.
    if actual_video_path:
        output_video_path = f"{output_dir}/{os.path.basename(actual_video_path)}"
        merge_video_audio(
            actual_video_path,
            output_audio_path,
            output_video_path,
            seconds_start,
            seconds_total
        )
    else:
        output_video_path = None

    del actual_video_path
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return output_video_path, output_audio_path
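
# A minimal sketch of calling generate_cond() directly, bypassing the UI. The prompt and
# seed are taken from the examples further down; the call is illustrative only, not an
# additional entry point of the app.
#
#   video_path, audio_path = generate_cond(
#       prompt="Typing on a keyboard",
#       seconds_start=0,
#       seconds_total=10,
#       cfg_scale=7.0,
#       steps=100,
#       seed=1225575558,
#   )
#   # audio_path -> "demo_result/output.wav"; video_path is None when no video is given.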
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🎧 AudioX: Diffusion Transformer for Anything-to-Audio Generation
        **[Paper](https://arxiv.org/abs/2503.10522) · [Project Page](https://zeyuet.github.io/AudioX/) · [Huggingface](https://huggingface.co/HKUSTAudio/AudioX) · [GitHub](https://github.com/ZeyueT/AudioX)**
        """
    )
| with gr.Tab("Generation"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| prompt = gr.Textbox( | |
| show_label=False, | |
| placeholder="Enter your prompt" | |
| ) | |
| negative_prompt = gr.Textbox( | |
| show_label=False, | |
| placeholder="Negative prompt", | |
| visible=False | |
| ) | |
| video_file = gr.File(label="Upload Video File") | |
| audio_prompt_file = gr.File( | |
| label="Upload Audio Prompt File", | |
| visible=False | |
| ) | |
| audio_prompt_path = gr.Textbox( | |
| label="Audio Prompt Path", | |
| placeholder="Enter audio file path", | |
| visible=False | |
| ) | |
        with gr.Row():
            with gr.Column(scale=6):
                with gr.Accordion("Video Params", open=False):
                    seconds_start = gr.Slider(
                        minimum=0,
                        maximum=512,
                        step=1,
                        value=0,
                        label="Video Seconds Start"
                    )
                    seconds_total = gr.Slider(
                        minimum=0,
                        maximum=10,
                        step=1,
                        value=10,
                        label="Seconds Total",
                        interactive=False
                    )
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Sampler Params", open=False):
                    steps = gr.Slider(
                        minimum=1,
                        maximum=500,
                        step=1,
                        value=100,
                        label="Steps"
                    )
                    preview_every = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=0,
                        label="Preview Every"
                    )
                    cfg_scale = gr.Slider(
                        minimum=0.0,
                        maximum=25.0,
                        step=0.1,
                        value=7.0,
                        label="CFG Scale"
                    )
                    seed = gr.Textbox(
                        label="Seed (set to -1 for random seed)",
                        value="-1"
                    )
                    sampler_type = gr.Dropdown(
                        choices=[
                            "dpmpp-2m-sde",
                            "dpmpp-3m-sde",
                            "k-heun",
                            "k-lms",
                            "k-dpmpp-2s-ancestral",
                            "k-dpm-2",
                            "k-dpm-fast"
                        ],
                        label="Sampler Type",
                        value="dpmpp-3m-sde"
                    )
                    sigma_min = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        step=0.01,
                        value=0.03,
                        label="Sigma Min"
                    )
                    sigma_max = gr.Slider(
                        minimum=0.0,
                        maximum=1000.0,
                        step=0.1,
                        value=500,
                        label="Sigma Max"
                    )
                    cfg_rescale = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.0,
                        label="CFG Rescale Amount"
                    )
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Accordion("Init Audio", open=False, visible=False):
                    init_audio_checkbox = gr.Checkbox(label="Use Init Audio")
                    init_audio_input = gr.Audio(label="Init Audio")
                    init_noise_level = gr.Slider(
                        minimum=0.1,
                        maximum=100.0,
                        step=0.01,
                        value=0.1,
                        label="Init Noise Level"
                    )
        with gr.Row():
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Row():
            with gr.Column(scale=6):
                video_output = gr.Video(label="Output Video", interactive=False)
                audio_output = gr.Audio(label="Output Audio", interactive=False)
        inputs = [
            prompt,
            negative_prompt,
            video_file,
            audio_prompt_file,
            audio_prompt_path,
            seconds_start,
            seconds_total,
            cfg_scale,
            steps,
            preview_every,
            seed,
            sampler_type,
            sigma_min,
            sigma_max,
            cfg_rescale,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level
        ]

        generate_button.click(
            fn=generate_cond,
            inputs=inputs,
            outputs=[video_output, audio_output]
        )
| gr.Markdown("## Examples") | |
| with gr.Accordion("Click to show examples", open=False): | |
| with gr.Row(): | |
| gr.Markdown("**π Task: Text-to-Audio**") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *Typing on a keyboard*") | |
| ex1 = gr.Button("Load Example") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *Ocean waves crashing*") | |
| ex2 = gr.Button("Load Example") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *Footsteps in snow*") | |
| ex3 = gr.Button("Load Example") | |
| with gr.Row(): | |
| gr.Markdown("**πΆ Task: Text-to-Music**") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *An orchestral music piece for a fantasy world.*") | |
| ex4 = gr.Button("Load Example") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *Produce upbeat electronic music for a dance party*") | |
| ex5 = gr.Button("Load Example") | |
| with gr.Column(scale=1.2): | |
| gr.Markdown("Prompt: *A dreamy lo-fi beat with vinyl crackle*") | |
| ex6 = gr.Button("Load Example") | |
| ex1.click(lambda: ["Typing on a keyboard", None, None, None, None, 0, 10, 7.0, 100, 0, "1225575558", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| ex2.click(lambda: ["Ocean waves crashing", None, None, None, None, 0, 10, 7.0, 100, 0, "3615819170", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| ex3.click(lambda: ["Footsteps in snow", None, None, None, None, 0, 10, 7.0, 100, 0, "1703896811", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| ex4.click(lambda: ["An orchestral music piece for a fantasy world.", None, None, None, None, 0, 10, 7.0, 100, 0, "1561898939", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| ex5.click(lambda: ["Produce upbeat electronic music for a dance party", None, None, None, None, 0, 10, 7.0, 100, 0, "406022999", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| ex6.click(lambda: ["A dreamy lo-fi beat with vinyl crackle", None, None, None, None, 0, 10, 7.0, 100, 0, "807934770", "dpmpp-3m-sde", 0.03, 500, 0.0, False, None, 0.1], inputs=[], outputs=inputs) | |
| interface.queue(5).launch(server_name="0.0.0.0", server_port=7860, share=True) |