mie237's picture
增加自动模式错误处理和用户提示信息
3b2e714
#!/usr/bin/env python3
import gzip
import json
import os
import uuid

import gradio as gr
import requests
from openai import OpenAI

from prompt_refiner_prompt import prompt_refiner_prompt as _PROMPT_REFINER_SYSTEM
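# API endpoints and Prompt Refiner settings, read from environment variables.
# Only AUDIOGEN_API_URL is mandatory; the remaining settings are optional or have built-in defaults.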
def _require_env(name: str) -> str:
    """Return the value of environment variable *name*, which must be non-empty.

    Raises:
        EnvironmentError: if the variable is unset or set to an empty string.
    """
    value = os.environ.get(name, "")
    if value:
        return value
    raise EnvironmentError(
        f"Required environment variable '{name}' is not set. "
        "Please set it before launching the app."
    )
# Mandatory: AudioGen synthesis endpoint (startup fails if unset).
AUDIOGEN_API_URL = _require_env("AUDIOGEN_API_URL")
# Optional endpoints; empty string when unset.
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "")
CLAW_API_URL = os.environ.get("CLAW_API_URL", "")
# XI API endpoint; overridable (like the endpoints above) for mirrored or
# self-hosted deployments. An empty value falls back to the public endpoint.
XI_API_BASE_URL = os.environ.get("XI_API_BASE_URL") or "https://api.xi-ai.cn/v1"
# Number of attempts each Prompt Refiner backend makes before giving up.
PROMPT_REFINER_MAX_RETRIES = 3
# Prompt Refiner backend selector: 'xi_api' (default), 'claw' or 'openai'.
PROMPT_REFINER_MODE = os.environ.get("PROMPT_REFINER_MODE", "xi_api")
# Field names in the order their tokens appear in a structured prompt.
SPECIAL_TOKEN_ORDER = ["caption", "speech", "sfx", "music", "env", "asr"]
# Literal marker emitted in front of each field's text (same key order as above).
SPECIAL_TOKEN_MAP = dict(
    caption="<|caption|>",
    speech="<|speech|>",
    sfx="<|sfx|>",
    music="<|music|>",
    env="<|env|>",
    asr="<|asr|>",
)
def build_structured_prompt(
    caption="", speech="", sfx="", music="", env="", asr=""
):
    """Join the non-empty fields into one structured prompt string.

    Fields are emitted in SPECIAL_TOKEN_ORDER as ``"<|token|> text"`` pairs,
    separated by single spaces. Each value is stripped first; ``None`` and
    blank values are skipped entirely.
    """
    values = dict(
        caption=caption, speech=speech, sfx=sfx, music=music, env=env, asr=asr
    )
    # Lazily pair each field name with its stripped text, in canonical order.
    stripped = (
        (name, (values[name] or "").strip()) for name in SPECIAL_TOKEN_ORDER
    )
    return " ".join(
        f"{SPECIAL_TOKEN_MAP[name]} {text}" for name, text in stripped if text
    )
def call_audiogen(structured_prompt):
    """POST a structured prompt to the AudioGen API and save the returned WAV.

    Args:
        structured_prompt: Prompt string, e.g. from ``build_structured_prompt``.
            ``None`` or blank input is rejected without a network call.

    Returns:
        ``(wav_path, "Generation successful!")`` on success, otherwise
        ``(None, error_message)``. Errors are reported, never raised, so the UI
        can show them in its status box.
    """
    if not (structured_prompt or "").strip():
        return None, "Error: Prompt is empty"
    try:
        response = requests.post(
            AUDIOGEN_API_URL,
            headers={"Content-Type": "application/json"},
            json={"text": structured_prompt},
            timeout=120,
        )
        response.raise_for_status()
        # An empty 200 body would otherwise be saved as an unplayable .wav.
        if not response.content:
            return None, "Error: AudioGen API returned an empty response."
        os.makedirs("./outputs", exist_ok=True)
        # One file per request: a single fixed filename let concurrent users
        # overwrite (or read a half-written copy of) each other's audio.
        # NOTE(review): files accumulate in ./outputs; consider periodic cleanup.
        output_path = os.path.join(
            "./outputs", f"audiogen_{uuid.uuid4().hex}.wav"
        )
        with open(output_path, "wb") as f:
            f.write(response.content)
        return output_path, "Generation successful!"
    except requests.exceptions.ConnectionError:
        return None, "Error: Cannot connect to AudioGen API. Please check the service."
    except requests.exceptions.HTTPError as e:
        return None, f"Error: HTTP {e.response.status_code} - {e.response.reason}"
    except requests.exceptions.Timeout:
        return None, "Error: Request timed out."
    except Exception as e:
        return None, f"Error: {str(e)}"
def _coerce_field_text(value) -> str:
    """Render one refined-field value as stripped text ('' for None/empty)."""
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        # e.g. {"sfx": ["door slam", "footsteps"]} -> "door slam, footsteps"
        return ", ".join(
            filter(None, (_coerce_field_text(item) for item in value))
        )
    return str(value).strip()


def _parse_and_validate(raw_content: str, attempt: int):
    """Parse the refiner's JSON reply and check the required 'caption' field.

    Args:
        raw_content: Text returned by the LLM; expected to be a JSON object.
            A surrounding markdown code fence (```json ... ```) is tolerated.
        attempt: 1-based attempt number, used only in error messages.

    Returns:
        ``(fields, None)`` on success, where ``fields`` maps lower-cased keys to
        non-empty, stripped strings (lists are joined with ", "); otherwise
        ``(None, error_message)``.
    """
    if not isinstance(raw_content, str) or not raw_content.strip():
        return None, f"Empty or non-text response on attempt {attempt}."
    text = raw_content.strip()
    # Chat models (especially when no JSON response_format is requested) often
    # wrap the object in a markdown code fence; unwrap it before parsing.
    if text.startswith("```") and text.endswith("```") and len(text) >= 6:
        inner = text[3:-3]
        # Drop an optional language tag such as "json" on the opening line.
        first_line, sep, rest = inner.partition("\n")
        if sep and not first_line.strip().startswith(("{", "[")):
            inner = rest
        text = inner.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        return None, f"Invalid JSON on attempt {attempt}: {e}"
    if not isinstance(parsed, dict):
        return None, (
            f"Expected a JSON object on attempt {attempt}, "
            f"got {type(parsed).__name__}."
        )
    # Values are coerced to text here so downstream prompt building (which
    # calls .strip()) never receives numbers or lists from the LLM.
    normalized = {}
    for key, value in parsed.items():
        text_value = _coerce_field_text(value)
        if text_value:
            normalized[str(key).lower()] = text_value
    if not normalized.get("caption"):
        return None, f"Missing required 'Caption' field on attempt {attempt}."
    return normalized, None
def _decode_claw_response_json(response: requests.Response) -> dict:
    """Return the JSON body of a CLAW response, tolerating mislabeled gzip.

    The undecoded bytes are read straight from ``response.raw``. They are first
    parsed as UTF-8 text. If that does not yield JSON, a gzip decompression of
    the same bytes is tried.

    Raises:
        ValueError: when neither interpretation parses as JSON; the message
            carries the last parse error.
    """
    body = response.raw.read(decode_content=False)

    def _texts():
        # Plain text first: some responses are uncompressed despite headers.
        yield body.decode("utf-8", errors="replace")
        # Then gzip, for bodies whose Content-Encoding is wrong or mixed.
        try:
            unzipped = gzip.decompress(body)
        except Exception:
            return
        yield unzipped.decode("utf-8", errors="replace")

    error = None
    for candidate in _texts():
        try:
            return json.loads(candidate)
        except Exception as exc:
            error = exc
    raise ValueError(f"Unable to decode CLAW JSON response: {error}")
def _call_prompt_refiner_claw(user_input: str, max_retries: int) -> dict:
    """Call the Prompt Refiner via the CLAW endpoint (no auth required).

    The whole prompt template, with ``{{user_input}}`` replaced by the user's
    text, is POSTed as a UTF-8 ``text/plain`` body. The reply is decoded with
    ``_decode_claw_response_json`` and read like an OpenAI chat completion
    (``choices[0].message.content``).

    Args:
        user_input: Free-form scene description typed by the user.
        max_retries: Maximum number of request/parse attempts.

    Returns:
        The normalized refiner dict produced by ``_parse_and_validate``.

    Raises:
        EnvironmentError: ``CLAW_API_URL`` is not configured.
        RuntimeError: on an HTTP or connection error (not retried), or when
            every attempt fails (timeouts, malformed payloads, invalid JSON).
    """
    # Fail fast with a configuration error instead of retrying an empty URL
    # and surfacing an opaque "Invalid URL" message.
    if not CLAW_API_URL:
        raise EnvironmentError(
            "CLAW_API_URL environment variable is not set. "
            "Please set it before using Auto Mode (claw mode)."
        )
    # Substitute user input into the prompt template
    full_prompt = _PROMPT_REFINER_SYSTEM.replace(
        "{{user_input}}", user_input
    ).strip()
    body = full_prompt.encode("utf-8")
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            # stream=True leaves the body undecoded so the raw bytes can be
            # read. requests only returns a streamed connection to the pool
            # once the response is closed or fully consumed; the ``with`` block
            # guarantees closing, including when raise_for_status() fires.
            with requests.post(
                CLAW_API_URL,
                headers={"Content-Type": "text/plain"},
                data=body,
                timeout=60,
                stream=True,
            ) as response:
                response.raise_for_status()
                # Response has same structure as OpenAI: choices[0].message.content
                resp_json = _decode_claw_response_json(response)
            raw_content = resp_json["choices"][0]["message"]["content"]
            result, err = _parse_and_validate(raw_content, attempt)
            if err:
                last_error = err
                continue
            return result
        except requests.exceptions.HTTPError as e:
            code = e.response.status_code
            raise RuntimeError(
                f"CLAW API HTTP error {code}: {e.response.reason}"
            ) from e
        except requests.exceptions.ConnectionError as e:
            raise RuntimeError(f"CLAW API connection error: {e}") from e
        except requests.exceptions.Timeout:
            last_error = f"CLAW API timed out on attempt {attempt}."
        except Exception as e:
            last_error = f"CLAW API error on attempt {attempt}: {e}"
    raise RuntimeError(
        f"Prompt Refiner (claw) failed after {max_retries} attempt(s). "
        f"Last error: {last_error}"
    )
def _call_prompt_refiner_openai(user_input: str, max_retries: int) -> dict:
    """Call the Prompt Refiner via an OpenAI-compatible chat completions endpoint.

    Settings come from the ``API_KEY`` and ``MODEL_NAME`` environment variables
    plus the module-level ``LLM_BASE_URL``. The prompt template, with its
    ``{{user_input}}`` placeholder removed, is the system message. The user's
    text is the user message, and JSON output is requested via
    ``response_format``.

    Args:
        user_input: Free-form scene description typed by the user.
        max_retries: Maximum number of request/parse attempts.

    Returns:
        The normalized refiner dict produced by ``_parse_and_validate``.

    Raises:
        EnvironmentError: a required setting is missing.
        RuntimeError: authentication/permission failure (not retried), or
            every attempt failed.
    """
    api_key = os.environ.get("API_KEY")
    model_name = os.environ.get("MODEL_NAME")
    if not api_key:
        raise EnvironmentError(
            "API_KEY environment variable is not set. "
            "Please set it before using Auto Mode (openai mode)."
        )
    if not model_name:
        raise EnvironmentError(
            "MODEL_NAME environment variable is not set. "
            "Please set it before using Auto Mode (openai mode)."
        )
    if not LLM_BASE_URL:
        raise EnvironmentError(
            "LLM_BASE_URL environment variable is not set. "
            "Please set it before using Auto Mode (openai mode)."
        )
    client = OpenAI(api_key=api_key, base_url=LLM_BASE_URL)
    system_content = _PROMPT_REFINER_SYSTEM.replace("{{user_input}}",
                                                    "").strip()
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "system",
                        "content": system_content
                    },
                    {
                        "role": "user",
                        "content": user_input
                    },
                ],
                max_completion_tokens=1024,
                response_format={"type": "json_object"},
            )
            raw_content = completion.choices[0].message.content
            result, err = _parse_and_validate(raw_content, attempt)
            if err:
                last_error = err
                continue
            return result
        except EnvironmentError:
            # OSError/EnvironmentError is propagated unchanged, not retried;
            # auto_generate reports it as a configuration error.
            raise
        except Exception as e:
            # Auth/permission failures will not fix themselves on retry. Prefer
            # the HTTP status the OpenAI SDK attaches (APIStatusError.status_code).
            # Plain substring checks for "401"/"403" also matched unrelated text
            # such as token counts ("... 4030 tokens").
            status_code = getattr(e, "status_code", None)
            err_str = str(e).lower()
            if status_code in (401, 403) or any(
                kw in err_str
                for kw in ("authentication", "api_key", "unauthorized")
            ):
                raise RuntimeError(f"Prompt Refiner auth error: {e}") from e
            last_error = f"API error on attempt {attempt}: {e}"
    raise RuntimeError(
        f"Prompt Refiner (openai) failed after {max_retries} attempt(s). "
        f"Last error: {last_error}"
    )
def _call_prompt_refiner_xi_api(user_input: str, max_retries: int) -> dict:
    """Call the Prompt Refiner via the XI API chat completions endpoint.

    Uses ``XI_API_KEY`` (required) and ``XI_MODEL_NAME`` (default
    ``"deepseek-v4-flash"``) from the environment and the module-level
    ``XI_API_BASE_URL``. The prompt template, with its ``{{user_input}}``
    placeholder removed, is the system message; the user's text is the user
    message.

    Args:
        user_input: Free-form scene description typed by the user.
        max_retries: Maximum number of request/parse attempts.

    Returns:
        The normalized refiner dict produced by ``_parse_and_validate``.

    Raises:
        EnvironmentError: ``XI_API_KEY`` is not set.
        RuntimeError: authentication/permission failure (not retried), or
            every attempt failed.
    """
    api_key = os.environ.get("XI_API_KEY")
    model_name = os.environ.get("XI_MODEL_NAME", "deepseek-v4-flash")
    if not api_key:
        raise EnvironmentError(
            "XI_API_KEY environment variable is not set. "
            "Please set it before using Auto Mode (xi_api mode)."
        )
    client = OpenAI(api_key=api_key, base_url=XI_API_BASE_URL)
    system_content = _PROMPT_REFINER_SYSTEM.replace("{{user_input}}",
                                                    "").strip()
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "system",
                        "content": system_content
                    },
                    {
                        "role": "user",
                        "content": user_input
                    },
                ],
            )
            raw_content = completion.choices[0].message.content
            result, err = _parse_and_validate(raw_content, attempt)
            if err:
                last_error = err
                continue
            return result
        except EnvironmentError:
            # OSError/EnvironmentError is propagated unchanged, not retried;
            # auto_generate reports it as a configuration error.
            raise
        except Exception as e:
            # Auth/permission failures will not fix themselves on retry. Prefer
            # the HTTP status the OpenAI SDK attaches (APIStatusError.status_code).
            # Plain substring checks for "401"/"403" also matched unrelated text
            # such as token counts ("... 4030 tokens").
            status_code = getattr(e, "status_code", None)
            err_str = str(e).lower()
            if status_code in (401, 403) or any(
                kw in err_str
                for kw in ("authentication", "api_key", "unauthorized")
            ):
                raise RuntimeError(f"Prompt Refiner auth error: {e}") from e
            last_error = f"XI API error on attempt {attempt}: {e}"
    raise RuntimeError(
        f"Prompt Refiner (xi_api) failed after {max_retries} attempt(s). "
        f"Last error: {last_error}"
    )
def call_prompt_refiner(user_input, max_retries=PROMPT_REFINER_MAX_RETRIES):
    """Dispatch to the configured Prompt Refiner backend.

    Mode is controlled by the PROMPT_REFINER_MODE environment variable
    (case-insensitive, surrounding whitespace ignored):
        'xi_api' — XI API chat completions endpoint (default, also used when
                   the variable is set but empty)
        'claw'   — CLAW plain-text endpoint, no auth required
        'openai' — OpenAI-compatible chat completions endpoint

    Returns:
        The refined field dict from the selected backend.

    Raises:
        ValueError: unknown mode. Backend errors propagate unchanged.
    """
    # Env files and CI secrets often carry stray whitespace/newlines; an empty
    # value falls back to the documented default instead of being "unknown".
    mode = PROMPT_REFINER_MODE.strip().lower() or "xi_api"
    backends = {
        "xi_api": _call_prompt_refiner_xi_api,
        "openai": _call_prompt_refiner_openai,
        "claw": _call_prompt_refiner_claw,
    }
    backend = backends.get(mode)
    if backend is None:
        raise ValueError(
            f"Unknown PROMPT_REFINER_MODE '{mode}'. "
            "Valid values: 'xi_api' (default), 'claw', 'openai'."
        )
    return backend(user_input, max_retries)
def build_caption_from_refined(refined: dict) -> str:
    """Turn a refined-field dict into the full structured prompt string.

    Thin adapter over ``build_structured_prompt``: each known field is looked up
    in *refined* (missing ones become ``""``) and passed through by name.
    """
    field_names = ("caption", "speech", "sfx", "music", "env", "asr")
    kwargs = {name: refined.get(name, "") for name in field_names}
    return build_structured_prompt(**kwargs)
def _build_auto_mode_llm_error(message: str) -> str:
    """Compose the status text shown when Auto Mode's Prompt Refiner call fails.

    *message* is embedded verbatim; the text also points users to Manual Mode
    and to the maintainer's contact address.
    """
    head = "Auto Mode is currently unavailable because the Prompt Refiner LLM API "
    tail = (
        "You can still try Manual Mode, which may remain available because it "
        "does not depend on the Prompt Refiner. "
        "If the issue continues, contact jiahaomei@sjtu.edu.cn."
    )
    return f"{head}failed: {message} {tail}"
def auto_generate(caption, progress=gr.Progress()):
    """Mode 1: refine a free-form description, build the prompt, call AudioGen.

    Returns ``(audio_path, structured_prompt, status)`` for the three Gradio
    outputs. Prompt Refiner failures are turned into a user-facing status
    message (via ``_build_auto_mode_llm_error``) instead of being raised.
    """
    if not (caption or "").strip():
        return None, "", "Error: Please enter a description."

    progress(0.1, desc="Calling Prompt Refiner...")
    try:
        refined = call_prompt_refiner(caption)
    except Exception as exc:
        # Classify the failure for the UI (same precedence as before:
        # configuration first, then refiner runtime errors, then anything else).
        if isinstance(exc, EnvironmentError):
            detail = f"Configuration error. {exc}"
        elif isinstance(exc, RuntimeError):
            detail = str(exc)
        else:
            detail = f"Unexpected error while calling Prompt Refiner. {exc}"
        return None, "", _build_auto_mode_llm_error(detail)

    progress(0.4, desc="Building structured prompt...")
    prompt_text = build_caption_from_refined(refined)
    progress(0.6, desc="Generating audio...")
    wav_path, status_msg = call_audiogen(prompt_text)
    progress(1.0)
    return wav_path, prompt_text, status_msg
def manual_generate(
    caption, speech, sfx, music, env, asr, progress=gr.Progress()
):
    """Mode 2: assemble the prompt from per-track fields and call AudioGen.

    Only *caption* is mandatory. Returns ``(audio_path, structured_prompt,
    status)`` for the three Gradio outputs.
    """
    caption_missing = not (caption or "").strip()
    if caption_missing:
        return None, "", "Error: Caption is required."

    progress(0.2, desc="Building structured prompt...")
    # Positional order matches build_structured_prompt's signature.
    prompt_text = build_structured_prompt(caption, speech, sfx, music, env, asr)
    progress(0.5, desc="Generating audio...")
    wav_path, status_msg = call_audiogen(prompt_text)
    progress(1.0)
    return wav_path, prompt_text, status_msg
# Custom CSS passed to gr.Blocks(css=...) below. Classes are attached to
# components through elem_classes or raw HTML:
#   .prompt-preview  - monospace, small font for the structured-prompt previews
#   .mode-radio      - bold labels on the Auto/Manual mode selector
#   .banner-warning  - amber notice banner, with dark-theme overrides for the
#                      banner and its links
custom_css = """
.prompt-preview textarea {
font-family: monospace !important;
font-size: 12px !important;
}
.mode-radio label {
font-weight: 600 !important;
}
.banner-warning {
padding: 12px 16px;
background: rgba(255, 193, 7, 0.12);
border: 2px solid #d4920a;
border-radius: 6px;
margin-bottom: 12px;
font-size: 14px;
line-height: 1.9;
}
.dark .banner-warning {
background: rgba(255, 193, 7, 0.07) !important;
border-color: #c8860a !important;
}
.banner-warning a {
color: #1a73e8;
}
.dark .banner-warning a {
color: #7aafff !important;
}
"""
# ── Mode switching helper ──────────────────────────────────────────────────────
def switch_mode(mode):
    """Show the section matching the selected radio option and hide the other.

    Returns a pair of ``gr.update`` objects for (auto_section, manual_section).
    """
    auto_selected = mode == "🤖 Auto Mode"
    auto_visibility = gr.update(visible=auto_selected)
    manual_visibility = gr.update(visible=not auto_selected)
    return auto_visibility, manual_visibility
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# Layout: intro markdown, a mode selector, then two sections (Auto / Manual).
# Only one section is visible at a time; the radio's change event at the bottom
# toggles them via switch_mode.
with gr.Blocks(
    title="Dasheng AudioGen Demo",
    theme=gr.themes.Soft(),
    css=custom_css,
) as demo:
    gr.Markdown("# 🔊 Dasheng AudioGen Demo")
    gr.Markdown("Developed by SJTU X-LANCE & Xiaomi LLM Plus")
    # Bilingual (Chinese / English) project summary.
    gr.Markdown(
        "支持结构化 Prompt 的混合音频生成,可用自然语言描述场景(Auto mode)或逐轨道填写(Manual mode),一次生成包含音乐、可理解人声和音效的完整音频。 \n"
        "Structured-prompt mixed audio generation that lets you describe a scene in natural language (Auto mode) or specify tracks manually (Manual mode), producing a complete audio clip with music, intelligible speech, and sound effects in one pass."
    )

    # Mode selector
    mode_radio = gr.Radio(
        choices=["🤖 Auto Mode", "✏️ Manual Mode"],
        value="🤖 Auto Mode",
        label="Generation Mode",
        interactive=True,
        elem_classes=["mode-radio"],
    )

    # ── Auto Mode section ──────────────────────────────────────────────────────
    # Free-form description -> Prompt Refiner -> structured prompt -> AudioGen.
    with gr.Column(visible=True) as auto_section:
        gr.Markdown(
            "## 🤖 Auto Mode \n"
            "If Auto Mode is unavailable, contact jiahaomei@sjtu.edu.cn."
        )
        gr.HTML(
            '<div class="banner-warning">'
            "⚠️ <strong>一次生成包含音乐、可理解人声和音效的完整音频。若生成音频质量较差或 Speech 内容不完整,可多尝试几次。</strong><br>"
            "⚠️ <strong>Producing a complete audio clip with music, intelligible speech, and sound effects in one pass. If the generated audio quality is poor or speech content is incomplete, please try generating again.</strong>"
            "<br><br>"
            "💬 你可以输入任意语言的音频描述,Prompt Refiner 会进行自动转换。目前 Speech 合成只支持英文,多语言支持即将上线。<br>"
            "💬 You can enter audio descriptions in any language — the Prompt Refiner will automatically convert them. "
            "Currently, speech synthesis only supports English. Multi-language support coming soon."
            "<br><br>"
            "🌐 Web Demo: "
            '<a href="https://nieeim.github.io/Dasheng-AudioGen-Web/" target="_blank">DashengAudioGen Web Demo</a><br>'
            "📦 GitHub Repo: "
            '<a href="https://github.com/NieeiM/Dasheng-Audiogen" target="_blank">DashengAudioGen GitHub Repository</a>'
            "<br>"
            "🦞 该模型也提供 OpenClaw Skill 调用 / Also available as an OpenClaw Skill: "
            '<a href="https://clawhub.ai/jimbozhang/midasheng-audio-generate" target="_blank">OpenClaw Skill 主页 / OpenClaw Skill Page</a>'
            "</div>"
        )
        gr.Markdown(
            "输入整体音频描述,系统将调用 **Prompt Refiner** 自动转换为结构化 Prompt,再通过 **DashengAudioGen** 生成音频。 \n"
            "Enter an overall audio description. The system will call the **Prompt Refiner** to convert it into a "
            "structured prompt, then generate audio via **DashengAudioGen**."
        )
        with gr.Row():
            # Left column: input; right column: outputs.
            with gr.Column():
                auto_caption = gr.Textbox(
                    label="整体音频描述 / Overall Audio Description",
                    placeholder=(
                        'e.g., A train station broadcast says, '
                        '"Train G128 is arriving on platform three, '
                        'please stand behind the yellow line." '
                        'with warning beeps, light orchestral bed, and station ambience.'
                    ),
                    lines=4,
                )
                auto_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                auto_audio = gr.Audio(
                    label="生成音频 / Generated Audio", type="filepath"
                )
                auto_prompt_preview = gr.Textbox(
                    label="结构化 Prompt 预览 / Structured Prompt (Preview)",
                    lines=4,
                    interactive=False,
                    elem_classes=["prompt-preview"],
                )
                auto_status = gr.Textbox(label="状态 / Status")
        gr.Examples(
            examples=[
                [
                    'A game announcer shouts, "Final round begins now, give it everything you have!" with crowd cheers, drum hits, and stadium ambience.',
                ],
                [
                    'A café barista says, "Your caramel latte is ready at the counter." while soft jazz plays, cups clink, and indoor café ambience continues.',
                ],
                [
                    'A train station broadcast says, "Train G128 is arriving on platform three, please stand behind the yellow line." with warning beeps, light orchestral bed, and station ambience.',
                ],
                [
                    'A radio host announces traffic updates over upbeat pop music with city street ambience.'
                ],
                ['安静的午后咖啡馆,一男一女在讨论天气'],
                [
                    '列车员使用英文报站,说“Train G128 is arriving on platform three, please stand behind the yellow line.”'
                ],
            ],
            inputs=[auto_caption],
            label="Examples",
        )
        # auto_generate returns (audio_path, structured_prompt, status).
        auto_button.click(
            fn=auto_generate,
            inputs=[auto_caption],
            outputs=[auto_audio, auto_prompt_preview, auto_status],
        )

    # ── Manual Mode section ────────────────────────────────────────────────────
    # Per-track fields -> structured prompt -> AudioGen (no Prompt Refiner).
    with gr.Column(visible=False) as manual_section:
        gr.Markdown(
            "## ✏️ Manual Mode \n"
            "If Manual Mode is unavailable, contact jiahaomei@sjtu.edu.cn."
        )
        gr.HTML(
            '<div class="banner-warning">'
            "⚠️ <strong>一次生成包含音乐、可理解人声和音效的完整音频。若生成音频质量较差或 Speech 内容不完整,可多尝试几次。</strong><br>"
            "⚠️ <strong>Producing a complete audio clip with music, intelligible speech, and sound effects in one pass. If the generated audio quality is poor or speech content is incomplete, please try generating again.</strong>"
            "<br><br>"
            "🔤 目前 Speech 合成只支持英文,多语言支持即将上线。<br>"
            "🔤 Currently, speech synthesis only supports English. Multi-language support coming soon."
            "<br><br>"
            "🌐 Web Demo: "
            '<a href="https://nieeim.github.io/Dasheng-AudioGen-Web/" target="_blank">DashengAudioGen Web Demo</a><br>'
            "📦 GitHub Repo: "
            '<a href="https://github.com/NieeiM/Dasheng-Audiogen" target="_blank">DashengAudioGen GitHub Repository</a>'
            "<br>"
            "🔗 该模型也提供 OpenClaw Skill 调用 / Also available as an OpenClaw Skill: "
            '<a href="https://clawhub.ai/jimbozhang/midasheng-audio-generate" target="_blank">OpenClaw Skill 主页 / OpenClaw Skill Page</a>'
            "</div>"
        )
        gr.Markdown(
            "逐轨道填写音频元素,仅 **Caption** 为必填项,其余字段均为可选。 \n"
            "Fill in each track individually. Only **Caption** is required; all other fields are optional."
        )
        with gr.Row():
            # Left column: per-track inputs; right column: outputs.
            with gr.Column():
                man_caption = gr.Textbox(
                    label="Caption — 整体描述 / Overall Description *",
                    placeholder=(
                        'e.g., A train station broadcast says, '
                        '"Train G128 is arriving on platform three, '
                        'please stand behind the yellow line." '
                        'with warning beeps and station ambience.'
                    ),
                    lines=3,
                )
                man_speech = gr.Textbox(
                    label="Speech — 说话人身份与风格 / Speaker Identity & Style",
                    placeholder="e.g., female announcer, calm and clear tone",
                    lines=1,
                )
                man_asr = gr.Textbox(
                    label="ASR — 语音文字 / Speech Transcript",
                    placeholder=(
                        "e.g., Train G128 is arriving on platform three, "
                        "please stand behind the yellow line."
                    ),
                    lines=2,
                )
                man_sfx = gr.Textbox(
                    label="SFX — 音效 / Sound Effects",
                    placeholder="e.g., warning beeps",
                    lines=1,
                )
                man_music = gr.Textbox(
                    label="Music — 背景音乐 / Background Music",
                    placeholder="e.g., light orchestral underscore",
                    lines=1,
                )
                man_env = gr.Textbox(
                    label="ENV — 环境音 / Environmental & Ambient Sound",
                    placeholder="e.g., train station ambience",
                    lines=1,
                )
                man_button = gr.Button("Generate Audio", variant="primary")
            with gr.Column():
                man_audio = gr.Audio(
                    label="生成音频 / Generated Audio", type="filepath"
                )
                man_prompt_preview = gr.Textbox(
                    label="结构化 Prompt 预览 / Structured Prompt (Preview)",
                    lines=5,
                    interactive=False,
                    elem_classes=["prompt-preview"],
                )
                man_status = gr.Textbox(label="状态 / Status")
        # Input order must match manual_generate's parameter order.
        man_button.click(
            fn=manual_generate,
            inputs=[
                man_caption, man_speech, man_sfx, man_music, man_env, man_asr
            ],
            outputs=[man_audio, man_prompt_preview, man_status],
        )
        # Each example row follows the same input order:
        # caption, speech, sfx, music, env, asr.
        gr.Examples(
            examples=[
                [
                    "A game announcer shouts with crowd cheers, drum hits, and indoor stadium ambience.",
                    "excited male game announcer",
                    "crowd cheers and impacts",
                    "energetic drum rhythm",
                    "indoor stadium ambience",
                    "Final round begins now, give it everything you have!",
                ],
                [
                    "A café barista makes an announcement while soft jazz plays, cups clink, and indoor café ambience continues.",
                    "barista announcement",
                    "cup and spoon clinks",
                    "soft jazz trio",
                    "indoor cafe ambience",
                    "Your caramel latte is ready at the counter.",
                ],
                [
                    "A train station broadcast makes an announcement with warning beeps, light orchestral bed, and station ambience.",
                    "station public announcement",
                    "warning beeps",
                    "light orchestral underscore",
                    "train station ambience",
                    "Train G128 is arriving on platform three, please stand behind the yellow line.",
                ],
            ],
            inputs=[
                man_caption, man_speech, man_sfx, man_music, man_env, man_asr
            ],
            label="Examples",
        )

    # Token reference table (pipes are escaped so the Markdown table renders).
    gr.Markdown(
        r"""
---
### Special Token 说明 / Special Token Reference
| Token | 字段 / Field | 说明 / Description |
|-------|------------|-------------------|
| `<\|caption\|>` | Caption | 整体音频场景描述(必填)/ Overall audio scene description (required) |
| `<\|speech\|>` | Speech | 说话人身份与风格 / Speaker identity & speaking style |
| `<\|asr\|>` | ASR | 语音文字内容 / Actual transcript of speech content |
| `<\|sfx\|>` | SFX | 音效描述 / Sound effects present in the audio |
| `<\|music\|>` | Music | 背景音乐描述 / Background music description |
| `<\|env\|>` | ENV | 环境音 / Environmental & ambient sound |
"""
    )

    # Mode switching event
    mode_radio.change(
        fn=switch_mode,
        inputs=[mode_radio],
        outputs=[auto_section, manual_section],
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)