import os
import tempfile
from pathlib import Path
import gradio as gr
try:
import piano_svsep
except ImportError as exc:
raise ImportError(
"The piano_svsep package is required. Activate the piano-engraving "
"conda environment or install dependencies from requirements.txt."
) from exc
import partitura as pt
import partitura.score as spt
import numpy as np
import torch
import torch_geometric as pyg
from piano_svsep.models.pl_models import PLPianoSVSep
from piano_svsep.utils.visualization import save_pyg_graph_as_json
from piano_svsep.utils import (
hetero_graph_from_note_array,
get_vocsep_features,
score_graph_to_pyg,
HeteroScoreGraph,
remove_ties_acros_barlines,
get_measurewise_pot_edges,
get_pot_chord_edges,
get_truth_chords_edges,
get_measurewise_truth_edges,
assign_voices
)


def prepare_score(path_to_score, include_original=True):
    """
    Prepare the score for voice separation.

    Parameters
    ----------
    path_to_score: str
        Path to the score file. Partitura can handle several formats such as MusicXML, MEI, etc.
    include_original: bool, optional
        Whether to include the original voice and chord assignments in the graph. Defaults to True.
        Mostly used for visualization purposes.

    Returns
    -------
    pg_graph: torch_geometric.data.HeteroData
        PyG HeteroData object containing the score graph.
    score: partitura.score.Score
        Partitura Score object.
    tie_couples: np.ndarray
        Array of (source, destination) note-id pairs for the ties removed across barlines.
    """
# Load the score
score = pt.load_score(path_to_score, force_note_ids=True)
if len(score) > 1:
score = pt.score.Score(pt.score.merge_parts(score.parts))
# Preprocess score for voice separation
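    # Ties that cross barlines are split so measures can be processed independently;
    # the returned (src, dst) id pairs are re-attached after prediction
    # (see tie_notes_over_measures).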
tie_couples = remove_ties_acros_barlines(score, return_ids=True)
# Remove beams
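    # Beaming depends on the final voice/staff assignment, so existing beams are
    # dropped here and re-inferred after prediction (see infer_beaming in predict_voice).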
for part in score:
beams = list(part.iter_all(pt.score.Beam))
for beam in beams:
beam_notes = beam.notes
for note in beam_notes:
note.beam = None
part.remove(beam)
# Remove rests
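    # The model reasons over notes only; rests are re-generated per voice after
    # prediction (see fill_rests in predict_voice).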
for part in score:
rests = list(part.iter_all(pt.score.Rest))
for rest in rests:
part.remove(rest)
# Remove Tuplets that contain rests
tuplets = list(part.iter_all(pt.score.Tuplet))
for tuplet in tuplets:
if isinstance(tuplet.start_note, pt.score.Rest) or isinstance(tuplet.end_note, pt.score.Rest):
part.remove(tuplet)
# Remove grace notes
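    # The graph model does not handle grace notes, so they are dropped before
    # feature extraction (verified again via include_grace_notes below).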
for part in score:
grace_notes = list(part.iter_all(pt.score.GraceNote))
for grace_note in grace_notes:
part.remove(grace_note)
# Create note array with necessary features
note_array = score[0].note_array(
include_time_signature=True,
        include_grace_notes=True,  # just to check that no grace notes are left
include_staff=True,
)
# Get the measure number for each note in the note array
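    # The map is taken from the part with the largest quarter duration (the finest
    # symbolic time division), so that onset_div values resolve to the correct measures.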
mn_map = score[np.array([p._quarter_durations[0] for p in score]).argmax()].measure_number_map
note_measures = mn_map(note_array["onset_div"])
# Create heterogeneous graph from note array
nodes, edges = hetero_graph_from_note_array(note_array, pot_edge_dist=0)
note_features = get_vocsep_features(note_array)
hg = HeteroScoreGraph(
note_features,
edges,
name="test_graph",
labels=None,
note_array=note_array,
)
# Get potential edges
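    # Candidate voice edges are restricted to note pairs within the same measure,
    # and candidate chord edges to simultaneous (onset-sharing) notes; the model
    # scores these candidates rather than all possible note pairs.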
pot_edges = get_measurewise_pot_edges(note_array, note_measures)
pot_chord_edges = get_pot_chord_edges(note_array, hg.get_edges_of_type("onset").numpy())
setattr(hg, "pot_edges", torch.tensor(pot_edges))
setattr(hg, "pot_chord_edges", torch.tensor(pot_chord_edges))
if include_original:
        # Get ground-truth edges from the original score; they are still called
        # "truth" even when the original voice assignment is itself questionable.
truth_chords_edges = get_truth_chords_edges(note_array, pot_chord_edges)
polyphonic_truth_edges = get_measurewise_truth_edges(note_array, note_measures)
setattr(hg, "truth_chord_edges", torch.tensor(truth_chords_edges).long())
setattr(hg, "truth_edges", torch.tensor(polyphonic_truth_edges).long())
# Convert score graph to PyG graph
pg_graph = score_graph_to_pyg(hg)
return pg_graph, score, tie_couples


def predict_voice(path_to_model, path_to_score, save_path=None):
    """
    Predict the voice assignment for a given score using a pre-trained model.

    Parameters
    ----------
    path_to_model: str
        Path to the pre-trained model checkpoint.
    path_to_score: str
        Path to the score file. Partitura can handle several formats such as MusicXML, MEI, etc.
    save_path: str, optional
        Path to save the predicted score. If None, the predicted score is saved next to the
        input score with '_pred.mei' appended to the stem. Defaults to None.

    Returns
    -------
    None
        The score object is updated in place and saved to the specified path.
    """
# Load the model
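    # strict=False lets the bundled checkpoint load even if some state-dict keys
    # differ from the current model definition.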
pl_model = PLPianoSVSep.load_from_checkpoint(path_to_model, map_location="cpu", strict=False, weights_only=False)
# Prepare the score
pg_graph, score, tied_notes = prepare_score(path_to_score)
    # Wrap the single graph in a batch, since the model expects batched input
pg_graph = pyg.data.Batch.from_data_list([pg_graph])
    # Predict the voice assignment
with torch.no_grad():
pl_model.module.eval()
pred_voices, pred_staff, pg_graph = pl_model.predict_step(pg_graph, return_graph=True)
# Partitura processing for visualization
part = score[0]
save_path = save_path if save_path is not None else os.path.splitext(path_to_score)[0] + "_pred.mei"
pg_graph.name = os.path.splitext(os.path.basename(save_path))[0]
save_pyg_graph_as_json(pg_graph, ids=part.note_array()["id"], path=os.path.dirname(save_path))
assign_voices(part, pred_voices, pred_staff)
tie_notes_over_measures(part, tied_notes)
spt.fill_rests(part, measurewise=True)
spt.infer_beaming(part)
print("Saving corrected score to", save_path)
if save_path.endswith(".mei"):
        pt.save_mei(part, save_path)
elif save_path.endswith(".musicxml") or save_path.endswith(".xml"):
pt.save_musicxml(part, save_path)
else:
raise ValueError("Unsupported file format. Please use .mei or .musicxml/.xml")
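
# Example of calling predict_voice directly, e.g. from a script (hypothetical
# paths; the Gradio UI below wraps the same call):
#   predict_voice("model.ckpt", "score.musicxml", save_path="score_pred.mei")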


def tie_notes_over_measures(part, tied_notes):
    """Re-attach the ties that were split at barlines during preprocessing."""
    # Build the id -> note lookup once, before any ties are restored:
    # part.notes_tied only lists notes that start a tie group, so querying it
    # lazily inside the loop could miss notes that have just become tied.
    notes_by_id = {note.id: note for note in part.notes_tied}
    for src, dst in tied_notes.T:
        src_note = notes_by_id.get(src)
        dst_note = notes_by_id.get(dst)
        if src_note is not None and dst_note is not None:
            src_note.tie_next = dst_note
            dst_note.tie_prev = src_note
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "model.ckpt"
def run_prediction(score_path: str | None, model_path: str | None):
"""Run voice prediction and return the path to the generated MEI file."""
if not score_path:
return None, "Please upload a score file first."
model_to_use = (model_path or "").strip() or str(DEFAULT_MODEL_PATH)
if not Path(model_to_use).exists():
return None, f"Model checkpoint not found at: {model_to_use}"
input_path = Path(score_path)
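    # Each request writes into its own temporary directory so concurrent runs
    # cannot overwrite each other's output.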
tmp_dir = Path(tempfile.mkdtemp(prefix="svsep_pred_"))
output_path = tmp_dir / f"{input_path.stem}_pred.musicxml"
try:
predict_voice(model_to_use, str(input_path), str(output_path))
except Exception as exc: # pragma: no cover - shown directly in UI
return None, f"Error during prediction: {exc}"
return str(output_path), f"Saved prediction to: {output_path}"


with gr.Blocks(title="Piano SVSep Voice Separation") as demo:
gr.Markdown(
"## Piano SVSep Voice Separation\n"
"Upload a MusicXML/MEI score and get back an MEI file with predicted voices."
)
with gr.Row():
with gr.Column():
score_input = gr.File(
label="Score file (.musicxml / .xml / .mei)",
file_types=[".musicxml", ".xml", ".mei"],
type="filepath",
)
model_input = gr.Textbox(
label="Model checkpoint path",
value=str(DEFAULT_MODEL_PATH),
placeholder="Path to model.ckpt",
)
run_button = gr.Button("Predict voices")
with gr.Column():
output_file = gr.File(label="Predicted MusicXML (download)")
status_box = gr.Textbox(label="Status", interactive=False)
run_button.click(
fn=run_prediction,
inputs=[score_input, model_input],
outputs=[output_file, status_box],
)


if __name__ == "__main__":
demo.launch()
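
# To expose the demo on a temporary public URL instead, Gradio also supports:
#   demo.launch(share=True)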