import os
import shutil
import tempfile
import traceback
from pathlib import Path

import gradio as gr

try:
    import piano_svsep  # noqa: F401 - imported only to give a clear error if missing
except ImportError as exc:
    raise ImportError(
        "The piano_svsep package is required. Activate the piano-engraving "
        "conda environment or install dependencies from requirements.txt."
    ) from exc

import partitura as pt
import partitura.score as spt
import numpy as np
import torch
import torch_geometric as pyg
from piano_svsep.models.pl_models import PLPianoSVSep
from piano_svsep.utils.visualization import save_pyg_graph_as_json
from piano_svsep.utils import (
    hetero_graph_from_note_array,
    get_vocsep_features,
    score_graph_to_pyg,
    HeteroScoreGraph,
    remove_ties_acros_barlines,
    get_measurewise_pot_edges,
    get_pot_chord_edges,
    get_truth_chords_edges,
    get_measurewise_truth_edges,
    assign_voices,
)


def prepare_score(path_to_score, include_original=True):
    """
    Prepare the score for voice separation.

    Loads the score, merges multiple parts into one, strips beams, rests,
    tuplets that start or end on a rest, and grace notes, then builds the
    heterogeneous score graph used by the voice-separation model.

    Parameters
    ----------
    path_to_score: str
        Path to the score file. Partitura can handle different formats such
        as musicxml, mei, etc.
    include_original: bool, optional
        Whether to include the original voice and chord assignments in the
        graph. Defaults to True. Mostly used for visualization purposes.

    Returns
    -------
    pg_graph: torch_geometric.data.HeteroData
        PyG HeteroData object containing the score graph.
    score: partitura.score.Score
        Partitura Score object (modified in place as described above).
    tie_couples: np.ndarray
        Ids of the note pairs whose ties across barlines were removed; used
        later to restore those ties.
    """
    # force_note_ids guarantees every note has an id; ids are used later to
    # map predictions and removed ties back onto notes.
    score = pt.load_score(path_to_score, force_note_ids=True)
    if len(score) > 1:
        # Merge into a single part so the pipeline can operate on score[0].
        score = pt.score.Score(pt.score.merge_parts(score.parts))

    # Remove ties that cross barlines, keeping their note-id pairs.
    tie_couples = remove_ties_acros_barlines(score, return_ids=True)

    # Remove beams (beaming is re-inferred after prediction).
    for part in score:
        for beam in list(part.iter_all(pt.score.Beam)):
            for note in beam.notes:
                note.beam = None
            part.remove(beam)

    # Remove rests (they are re-filled after prediction) and any tuplet that
    # starts or ends on a rest.
    for part in score:
        for rest in list(part.iter_all(pt.score.Rest)):
            part.remove(rest)
        for tuplet in list(part.iter_all(pt.score.Tuplet)):
            if isinstance(tuplet.start_note, pt.score.Rest) or isinstance(
                tuplet.end_note, pt.score.Rest
            ):
                part.remove(tuplet)

    # Remove grace notes.
    for part in score:
        for grace_note in list(part.iter_all(pt.score.GraceNote)):
            part.remove(grace_note)

    # Create note array with the features needed by the graph builder.
    note_array = score[0].note_array(
        include_time_signature=True,
        include_grace_notes=True,  # only to check that no grace notes are left
        include_staff=True,
    )

    # Measure number for each note.
    # NOTE(review): the part is picked by argmax of its first private
    # `_quarter_durations` entry; after the merge above there is one part.
    mn_map = score[np.array([p._quarter_durations[0] for p in score]).argmax()].measure_number_map
    note_measures = mn_map(note_array["onset_div"])

    # Create heterogeneous graph from the note array.
    _, edges = hetero_graph_from_note_array(note_array, pot_edge_dist=0)
    note_features = get_vocsep_features(note_array)
    hg = HeteroScoreGraph(
        note_features,
        edges,
        name="test_graph",
        labels=None,
        note_array=note_array,
    )

    # Candidate (potential) voice edges and chord edges.
    pot_edges = get_measurewise_pot_edges(note_array, note_measures)
    pot_chord_edges = get_pot_chord_edges(note_array, hg.get_edges_of_type("onset").numpy())
    hg.pot_edges = torch.tensor(pot_edges)
    hg.pot_chord_edges = torch.tensor(pot_chord_edges)

    if include_original:
        # Edges implied by the score's original voice/chord encoding. They are
        # called "truth" even though that encoding may itself be wrong.
        truth_chords_edges = get_truth_chords_edges(note_array, pot_chord_edges)
        polyphonic_truth_edges = get_measurewise_truth_edges(note_array, note_measures)
        hg.truth_chord_edges = torch.tensor(truth_chords_edges).long()
        hg.truth_edges = torch.tensor(polyphonic_truth_edges).long()

    # Convert score graph to PyG graph.
    pg_graph = score_graph_to_pyg(hg)
    return pg_graph, score, tie_couples


def predict_voice(path_to_model, path_to_score, save_path=None):
    """
    Predict the voice assignment for a given score using a pre-trained model.

    Parameters
    ----------
    path_to_model: str
        Path to the pre-trained model checkpoint.
    path_to_score: str
        Path to the score file. Partitura can handle different formats such
        as musicxml, mei, etc.
    save_path: str, optional
        Path to save the predicted score. If None, the predicted score is
        saved next to the input score with its extension replaced by
        '_pred.mei'. The extension selects the writer: '.mei' or
        '.musicxml'/'.xml' (case-insensitive). Defaults to None.

    Returns
    -------
    None
        The predicted score is written to ``save_path``; a JSON dump of the
        prediction graph (named after ``save_path``'s stem) is written to the
        same directory.

    Raises
    ------
    ValueError
        If ``save_path`` has an unsupported extension. This is checked before
        the model is loaded, so no work is done for an invalid target.
    """
    if save_path is None:
        save_path = os.path.splitext(path_to_score)[0] + "_pred.mei"

    # Validate the output format up front rather than after inference.
    lowered = save_path.lower()
    if lowered.endswith(".mei"):
        save_fn = pt.save_mei
    elif lowered.endswith((".musicxml", ".xml")):
        save_fn = pt.save_musicxml
    else:
        raise ValueError("Unsupported file format. Please use .mei or .musicxml/.xml")

    # NOTE: weights_only=False unpickles the checkpoint, which can execute
    # arbitrary code; only load checkpoints from trusted sources.
    pl_model = PLPianoSVSep.load_from_checkpoint(
        path_to_model, map_location="cpu", strict=False, weights_only=False
    )

    pg_graph, score, tied_notes = prepare_score(path_to_score)
    # Batch for compatibility with the model's predict_step.
    pg_graph = pyg.data.Batch.from_data_list([pg_graph])

    with torch.no_grad():
        pl_model.module.eval()
        pred_voices, pred_staff, pg_graph = pl_model.predict_step(pg_graph, return_graph=True)

    # Partitura post-processing for visualization and export.
    part = score[0]
    pg_graph.name = os.path.splitext(os.path.basename(save_path))[0]
    # Fall back to the current directory when save_path has no directory part.
    save_pyg_graph_as_json(
        pg_graph,
        ids=part.note_array()["id"],
        path=os.path.dirname(save_path) or ".",
    )
    assign_voices(part, pred_voices, pred_staff)
    tie_notes_over_measures(part, tied_notes)
    spt.fill_rests(part, measurewise=True)
    spt.infer_beaming(part)

    print("Saving corrected score to", save_path)
    save_fn(part, save_path)


def tie_notes_over_measures(part, tied_notes):
    """
    Restore ties between note pairs identified by note id.

    Parameters
    ----------
    part: partitura.score.Part
        Part whose notes are linked in place (``tie_next`` / ``tie_prev``).
    tied_notes: np.ndarray
        Note-id pairs iterated as ``tied_notes.T``, each item being
        ``(src_id, dst_id)``; presumably shape (2, k) - TODO confirm.
        Pairs whose ids are not found in ``part`` are skipped.
    """
    # Index every note by id once, instead of rescanning for each pair. All
    # notes are indexed, not just `notes_tied`: that view omits notes that
    # already have a `tie_prev`, which would drop ties whose source continues
    # an earlier tie (including chains restored by this very loop).
    notes_by_id = {}
    for note in part.notes:
        notes_by_id.setdefault(note.id, note)  # first match wins, like a linear search

    for src, dst in tied_notes.T:
        src_note = notes_by_id.get(src)
        dst_note = notes_by_id.get(dst)
        if src_note is not None and dst_note is not None:
            src_note.tie_next = dst_note
            dst_note.tie_prev = src_note


DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "model.ckpt"


def run_prediction(score_path: str | None, model_path: str | None):
    """
    Run voice prediction for the Gradio UI.

    Parameters
    ----------
    score_path: str | None
        Path of the uploaded score, or None/empty when nothing was uploaded.
    model_path: str | None
        Checkpoint path; blank values fall back to ``DEFAULT_MODEL_PATH``.

    Returns
    -------
    tuple[str | None, str]
        ``(output_path, status_message)``. ``output_path`` is the generated
        MusicXML file, or None if validation or prediction failed.
    """
    if not score_path:
        return None, "Please upload a score file first."

    model_to_use = (model_path or "").strip() or str(DEFAULT_MODEL_PATH)
    if not Path(model_to_use).exists():
        return None, f"Model checkpoint not found at: {model_to_use}"

    input_path = Path(score_path)
    # On success the directory is kept so Gradio can serve the file.
    tmp_dir = Path(tempfile.mkdtemp(prefix="svsep_pred_"))
    output_path = tmp_dir / f"{input_path.stem}_pred.musicxml"

    try:
        predict_voice(model_to_use, str(input_path), str(output_path))
    except Exception as exc:  # pragma: no cover - shown directly in UI
        # The UI shows only the message; keep the full traceback in the console.
        traceback.print_exc()
        # Nothing usable remains in the temp dir after a failure.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return None, f"Error during prediction: {exc}"

    return str(output_path), f"Saved prediction to: {output_path}"


with gr.Blocks(title="Piano SVSep Voice Separation") as demo:
    gr.Markdown(
        "## Piano SVSep Voice Separation\n"
        "Upload a MusicXML/MEI score and get back a MusicXML file with predicted voices."
    )
    with gr.Row():
        with gr.Column():
            score_input = gr.File(
                label="Score file (.musicxml / .xml / .mei)",
                file_types=[".musicxml", ".xml", ".mei"],
                type="filepath",
            )
            model_input = gr.Textbox(
                label="Model checkpoint path",
                value=str(DEFAULT_MODEL_PATH),
                placeholder="Path to model.ckpt",
            )
            run_button = gr.Button("Predict voices")
        with gr.Column():
            output_file = gr.File(label="Predicted MusicXML (download)")
            status_box = gr.Textbox(label="Status", interactive=False)

    run_button.click(
        fn=run_prediction,
        inputs=[score_input, model_input],
        outputs=[output_file, status_box],
    )


if __name__ == "__main__":
    demo.launch()