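"""Gradio demo for piano voice separation with the piano_svsep model.

Load a score with partitura, build a heterogeneous note graph, run a
pre-trained PLPianoSVSep checkpoint to predict voice and staff assignments,
and write the result back out as MusicXML or MEI.
"""
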
import os
import tempfile
from pathlib import Path

import gradio as gr
try:
    import piano_svsep
except ImportError as exc:
    raise ImportError(
        "The piano_svsep package is required. Activate the piano-engraving "
        "conda environment or install dependencies from requirements.txt."
    ) from exc
import partitura as pt
import partitura.score as spt
import numpy as np
import torch
import torch_geometric as pyg
from piano_svsep.models.pl_models import PLPianoSVSep
from piano_svsep.utils.visualization import save_pyg_graph_as_json
from piano_svsep.utils import (
    hetero_graph_from_note_array,
    get_vocsep_features,
    score_graph_to_pyg,
    HeteroScoreGraph,
    remove_ties_acros_barlines,
    get_measurewise_pot_edges,
    get_pot_chord_edges,
    get_truth_chords_edges,
    get_measurewise_truth_edges,
    assign_voices
)


def prepare_score(path_to_score, include_original=True):
    """
    Prepare the score for voice separation.

    Parameters
    ----------
    path_to_score: str
        Path to the score file. Partitura can handle different formats such as musicxml, mei, etc.
    include_original: bool, optional
        Whether to include the original voice and chord assignments in the graph. Defaults to True.
        Mostly used for visualization purposes.

    Returns
    -------
    pg_graph: torch_geometric.data.HeteroData
        PyG HeteroData object containing the score graph.
    score: partitura.score.Score
        Partitura Score object.
    tie_couples: np.ndarray
        2xN array of (source, destination) ids of notes tied across barlines.
    """
    # Load the score
    score = pt.load_score(path_to_score, force_note_ids=True)
    if len(score) > 1:
        score = pt.score.Score(pt.score.merge_parts(score.parts))

    # Preprocess score for voice separation
    tie_couples = remove_ties_acros_barlines(score, return_ids=True)

    # Remove beams
    for part in score:
        beams = list(part.iter_all(pt.score.Beam))
        for beam in beams:
            beam_notes = beam.notes
            for note in beam_notes:
                note.beam = None
            part.remove(beam)

    # Remove rests
    for part in score:
        rests = list(part.iter_all(pt.score.Rest))
        for rest in rests:
            part.remove(rest)
        # Remove Tuplets that contain rests
        tuplets = list(part.iter_all(pt.score.Tuplet))
        for tuplet in tuplets:
            if isinstance(tuplet.start_note, pt.score.Rest) or isinstance(tuplet.end_note, pt.score.Rest):
                part.remove(tuplet)

    # Remove grace notes
    for part in score:
        grace_notes = list(part.iter_all(pt.score.GraceNote))
        for grace_note in grace_notes:
            part.remove(grace_note)

    # Create note array with necessary features
    note_array = score[0].note_array(
        include_time_signature=True,
        include_grace_notes=True,  # only used to verify that no grace notes remain
        include_staff=True,
    )

    # Map each note onset to a measure number, using the measure map of the part
    # with the largest number of divisions per quarter (finest temporal resolution)
    mn_map = score[np.array([p._quarter_durations[0] for p in score]).argmax()].measure_number_map
    note_measures = mn_map(note_array["onset_div"])

    # Create heterogeneous graph from note array
    nodes, edges = hetero_graph_from_note_array(note_array, pot_edge_dist=0)
    note_features = get_vocsep_features(note_array)
    hg = HeteroScoreGraph(
        note_features,
        edges,
        name="test_graph",
        labels=None,
        note_array=note_array,
    )

    # Candidate edges for the model to score: within-measure voice-successor
    # candidates and same-onset chord candidates
    pot_edges = get_measurewise_pot_edges(note_array, note_measures)
    pot_chord_edges = get_pot_chord_edges(note_array, hg.get_edges_of_type("onset").numpy())
    setattr(hg, "pot_edges", torch.tensor(pot_edges))
    setattr(hg, "pot_chord_edges", torch.tensor(pot_chord_edges))

    if include_original:
        # Ground-truth voice and chord edges derived from the original score;
        # these are still called "truth" even if the original assignment is wrong
        truth_chords_edges = get_truth_chords_edges(note_array, pot_chord_edges)
        polyphonic_truth_edges = get_measurewise_truth_edges(note_array, note_measures)
        setattr(hg, "truth_chord_edges", torch.tensor(truth_chords_edges).long())
        setattr(hg, "truth_edges", torch.tensor(polyphonic_truth_edges).long())

    # Convert score graph to PyG graph
    pg_graph = score_graph_to_pyg(hg)

    return pg_graph, score, tie_couples
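
# Example usage (hypothetical path; a minimal sketch):
#
#     pg_graph, score, ties = prepare_score("score.musicxml")
#     # pg_graph is a torch_geometric HeteroData, ready for Batch.from_data_list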


def predict_voice(path_to_model, path_to_score, save_path=None):
    """
    Predict the voice assignment for a given score using a pre-trained model.

    Parameters
    ----------
    path_to_model: str
        Path to the pre-trained model checkpoint.
    path_to_score: str
        Path to the score file. Partitura can handle different formats such as musicxml, mei, etc.
    save_path: str, optional
        Path to save the predicted score. If None, the predicted score will be saved in the same directory as the input score with '_pred' appended to the filename. Defaults to None.

    Returns
    -------
    None
        Updates are made to the score object and saved to the specified path.
    """
    # Load the model
    pl_model = PLPianoSVSep.load_from_checkpoint(path_to_model, map_location="cpu", strict=False, weights_only=False)
    # Prepare the score
    pg_graph, score, tied_notes = prepare_score(path_to_score)
    # Batch for compatibility
    pg_graph = pyg.data.Batch.from_data_list([pg_graph])
    # Predict the voice assignment
    with torch.no_grad():
        pl_model.module.eval()
        pred_voices, pred_staff, pg_graph = pl_model.predict_step(pg_graph, return_graph=True)
    # Partitura processing for visualization
    part = score[0]
    save_path = save_path if save_path is not None else os.path.splitext(path_to_score)[0] + "_pred.mei"
    pg_graph.name = os.path.splitext(os.path.basename(save_path))[0]
    save_pyg_graph_as_json(pg_graph, ids=part.note_array()["id"], path=os.path.dirname(save_path))
    assign_voices(part, pred_voices, pred_staff)
    tie_notes_over_measures(part, tied_notes)
    spt.fill_rests(part, measurewise=True)
    spt.infer_beaming(part)
    print("Saving corrected score to", save_path)
    if save_path.endswith(".mei"):
        pt.save_mei(part, save_path)
    elif save_path.endswith(".musicxml") or save_path.endswith(".xml"):
        pt.save_musicxml(part, save_path)
    else:
        raise ValueError("Unsupported file format. Please use .mei or .musicxml/.xml")
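
# Example usage (hypothetical paths; a minimal sketch):
#
#     predict_voice("model.ckpt", "score.musicxml", save_path="score_pred.musicxml")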


def tie_notes_over_measures(part, tied_notes):
    """Restore ties across barlines that were removed during preprocessing.

    Parameters
    ----------
    part: partitura.score.Part
        The part whose notes should be re-tied.
    tied_notes: np.ndarray
        2xN array of (source, destination) note ids, as returned by
        remove_ties_acros_barlines.
    """
    # Build the id -> note lookup once instead of scanning the part per pair
    notes_by_id = {note.id: note for note in part.notes_tied}
    for src, dst in tied_notes.T:
        src_note = notes_by_id.get(src)
        dst_note = notes_by_id.get(dst)
        if src_note is not None and dst_note is not None:
            src_note.tie_next = dst_note
            dst_note.tie_prev = src_note



DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "model.ckpt"


def run_prediction(score_path: str | None, model_path: str | None):
    """Run voice prediction and return the path to the generated MEI file."""
    if not score_path:
        return None, "Please upload a score file first."

    model_to_use = (model_path or "").strip() or str(DEFAULT_MODEL_PATH)

    if not Path(model_to_use).exists():
        return None, f"Model checkpoint not found at: {model_to_use}"

    input_path = Path(score_path)
    tmp_dir = Path(tempfile.mkdtemp(prefix="svsep_pred_"))
    output_path = tmp_dir / f"{input_path.stem}_pred.musicxml"

    try:
        predict_voice(model_to_use, str(input_path), str(output_path))
    except Exception as exc:  # pragma: no cover - shown directly in UI
        return None, f"Error during prediction: {exc}"

    return str(output_path), f"Saved prediction to: {output_path}"


with gr.Blocks(title="Piano SVSep Voice Separation") as demo:
    gr.Markdown(
        "## Piano SVSep Voice Separation\n"
        "Upload a MusicXML/MEI score and get back an MEI file with predicted voices."
    )

    with gr.Row():
        with gr.Column():
            score_input = gr.File(
                label="Score file (.musicxml / .xml / .mei)",
                file_types=[".musicxml", ".xml", ".mei"],
                type="filepath",
            )
            model_input = gr.Textbox(
                label="Model checkpoint path",
                value=str(DEFAULT_MODEL_PATH),
                placeholder="Path to model.ckpt",
            )
            run_button = gr.Button("Predict voices")

        with gr.Column():
            output_file = gr.File(label="Predicted MusicXML (download)")
            status_box = gr.Textbox(label="Status", interactive=False)

    run_button.click(
        fn=run_prediction,
        inputs=[score_input, model_input],
        outputs=[output_file, status_box],
    )


if __name__ == "__main__":
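    # Standard Gradio options: demo.launch(share=True) creates a temporary
    # public link; demo.launch(server_name="0.0.0.0") binds all interfaces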
    demo.launch()