AIvry commited on
Commit
b759ccc
·
verified ·
1 Parent(s): 424791f

Upload 11 files

Browse files
Files changed (10) hide show
  1. app.py +35 -17
  2. argshield.py +1 -2
  3. audio.py +49 -32
  4. config.py +35 -35
  5. distortions.py +6 -2
  6. engine.py +206 -185
  7. main.py +8 -0
  8. metrics.py +64 -12
  9. models.py +43 -18
  10. utils.py +27 -90
app.py CHANGED
@@ -66,6 +66,8 @@ def process_audio_files(zip_file, model_name, layer, alpha):
66
  return None, "No reference WAV files found"
67
  if len(out_files) == 0:
68
  return None, "No output WAV files found"
 
 
69
 
70
  # Create manifest
71
  manifest = [{
@@ -92,8 +94,8 @@ def process_audio_files(zip_file, model_name, layer, alpha):
92
  }
93
  layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
94
 
95
- # Check GPU availability
96
- max_gpus = 1 if torch.cuda.is_available() else 0
97
 
98
  # Run experiment
99
  results_dir = compute_mapss_measures(
@@ -103,7 +105,7 @@ def process_audio_files(zip_file, model_name, layer, alpha):
103
  alpha=alpha,
104
  verbose=True,
105
  max_gpus=max_gpus,
106
- add_ci=False
107
  )
108
 
109
  # Create output ZIP at a fixed location
@@ -121,7 +123,7 @@ def process_audio_files(zip_file, model_name, layer, alpha):
121
  files_added += 1
122
 
123
  if output_zip.exists() and files_added > 0:
124
- return str(output_zip), f"Processing completed! Created ZIP with {files_added} files."
125
  else:
126
  return None, f"Processing completed but no output files were generated. Check if embeddings were computed."
127
 
@@ -143,6 +145,13 @@ def create_interface():
143
  - **Perceptual Matching (PM)**: Measures how closely an output perceptually aligns with its reference. Range: 0-1, higher is better.
144
  - **Perceptual Similarity (PS)**: Measures how well an output is separated from its interfering references. Range: 0-1, higher is better.
145
 
 
 
 
 
 
 
 
146
  ## Input Format
147
 
148
  Upload a ZIP file containing:
@@ -152,9 +161,9 @@ def create_interface():
152
  │ ├── speaker1.wav
153
  │ ├── speaker2.wav
154
  │ └── ...
155
- └── outputs/ # Separated outputs from your algorithm
156
- ├── separated1.wav
157
- ├── separated2.wav
158
  └── ...
159
  ```
160
 
@@ -162,16 +171,22 @@ def create_interface():
162
  - Format: .wav files
163
  - Sample rate: Any (automatically resampled to 16kHz)
164
  - Channels: Mono or stereo (converted to mono)
165
- - Number of files: Equal number of references and outputs
 
166
 
167
  ## Output Format
168
 
169
  The tool generates a ZIP file containing:
170
- - `ps_scores_{model}.csv`: PS scores for each source
171
- - `pm_scores_{model}.csv`: PM scores for each source
172
  - `params.json`: Parameters used
173
  - `manifest_canonical.json`: File mapping and processing details
174
 
 
 
 
 
 
175
  ## Available Models
176
 
177
  | Model | Description | Default Layer | Use Case |
@@ -179,10 +194,10 @@ def create_interface():
179
  | `raw` | Raw waveform features | N/A | Baseline comparison |
180
  | `wavlm` | WavLM Large | 24 | Strong performance |
181
  | `wav2vec2` | Wav2Vec2 Large | 24 | Best overall performance |
182
- | `hubert` | HuBERT Large | 24 | |
183
- | `wavlm_base` | WavLM Base | 12 | |
184
  | `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster, good quality |
185
- | `hubert_base` | HuBERT Base | 12 | |
186
  | `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
187
 
188
  ## Parameters
@@ -193,6 +208,13 @@ def create_interface():
193
  - 0.0 = No normalization
194
  - 1.0 = Full normalization (recommended)
195
 
 
 
 
 
 
 
 
196
  ## Citation
197
 
198
  If you use MAPSS, please cite:
@@ -207,10 +229,6 @@ def create_interface():
207
  }
208
  ```
209
 
210
- ## Limitations
211
-
212
- - Processing time scales with number of sources, audio length and model size
213
-
214
  ## License
215
 
216
  Code: MIT License
 
66
  return None, "No reference WAV files found"
67
  if len(out_files) == 0:
68
  return None, "No output WAV files found"
69
+ if len(ref_files) != len(out_files):
70
+ return None, f"Number of reference files ({len(ref_files)}) must match number of output files ({len(out_files)}). Files must be in the same order."
71
 
72
  # Create manifest
73
  manifest = [{
 
94
  }
95
  layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
96
 
97
+ # Check GPU availability - use all available GPUs on the space
98
+ max_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
99
 
100
  # Run experiment
101
  results_dir = compute_mapss_measures(
 
105
  alpha=alpha,
106
  verbose=True,
107
  max_gpus=max_gpus,
108
+ add_ci=False # Disable CI for faster processing in demo
109
  )
110
 
111
  # Create output ZIP at a fixed location
 
123
  files_added += 1
124
 
125
  if output_zip.exists() and files_added > 0:
126
+ return str(output_zip), f"Processing completed! Created ZIP with {files_added} files. Note: Output files must be in the same order as reference files."
127
  else:
128
  return None, f"Processing completed but no output files were generated. Check if embeddings were computed."
129
 
 
145
  - **Perceptual Matching (PM)**: Measures how closely an output perceptually aligns with its reference. Range: 0-1, higher is better.
146
  - **Perceptual Similarity (PS)**: Measures how well an output is separated from its interfering references. Range: 0-1, higher is better.
147
 
148
+ ## ⚠️ IMPORTANT: File Order Requirements
149
+
150
+ **Output files MUST be in the same order as reference files!**
151
+ - If references are: `speaker1.wav`, `speaker2.wav`, `speaker3.wav`
152
+ - Then outputs must be: `output1.wav`, `output2.wav`, `output3.wav`
153
+ - Where `output1` corresponds to `speaker1`, `output2` to `speaker2`, etc.
154
+
155
  ## Input Format
156
 
157
  Upload a ZIP file containing:
 
161
  │ ├── speaker1.wav
162
  │ ├── speaker2.wav
163
  │ └── ...
164
+ └── outputs/ # Separated outputs (SAME ORDER as references)
165
+ ├── separated1.wav # Must correspond to speaker1.wav
166
+ ├── separated2.wav # Must correspond to speaker2.wav
167
  └── ...
168
  ```
169
 
 
171
  - Format: .wav files
172
  - Sample rate: Any (automatically resampled to 16kHz)
173
  - Channels: Mono or stereo (converted to mono)
174
+ - **Number of files: Equal number of references and outputs**
175
+ - **Order: Output files must be in the same order as reference files**
176
 
177
  ## Output Format
178
 
179
  The tool generates a ZIP file containing:
180
+ - `ps_scores_{model}.csv`: PS scores for each source over time
181
+ - `pm_scores_{model}.csv`: PM scores for each source over time
182
  - `params.json`: Parameters used
183
  - `manifest_canonical.json`: File mapping and processing details
184
 
185
+ ### Score Interpretation
186
+ - **NaN values**: Appear in frames where fewer than 2 speakers are active
187
+ - **Valid scores**: Only computed when at least 2 speakers are active in a frame
188
+ - **Time resolution**: 20ms frames (configurable in code)
189
+
190
  ## Available Models
191
 
192
  | Model | Description | Default Layer | Use Case |
 
194
  | `raw` | Raw waveform features | N/A | Baseline comparison |
195
  | `wavlm` | WavLM Large | 24 | Strong performance |
196
  | `wav2vec2` | Wav2Vec2 Large | 24 | Best overall performance |
197
+ | `hubert` | HuBERT Large | 24 | Good for speech |
198
+ | `wavlm_base` | WavLM Base | 12 | Faster processing |
199
  | `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster, good quality |
200
+ | `hubert_base` | HuBERT Base | 12 | Faster processing |
201
  | `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
202
 
203
  ## Parameters
 
208
  - 0.0 = No normalization
209
  - 1.0 = Full normalization (recommended)
210
 
211
+ ## Processing Notes
212
+
213
+ - The system automatically detects which speakers are active in each frame
214
+ - PS/PM scores are only computed between active speakers
215
+ - Processing time scales with number of sources and audio length
216
+ - GPU acceleration is automatically used when available
217
+
218
  ## Citation
219
 
220
  If you use MAPSS, please cite:
 
229
  }
230
  ```
231
 
 
 
 
 
232
  ## License
233
 
234
  Code: MIT License
argshield.py CHANGED
@@ -7,7 +7,6 @@ import importlib.util
7
  from config import DEFAULT_ALPHA
8
  from models import get_model_config
9
 
10
- # Central table for default layers per model (kept identical to original table)
11
  MODEL_DEFAULT_LAYER = {
12
  "raw": None,
13
  "wavlm": 24,
@@ -31,7 +30,7 @@ def _read_manifest_py(path: Path):
31
  if spec is None or spec.loader is None:
32
  raise SystemExit(f"Could not load Python manifest: {path}")
33
  mod = importlib.util.module_from_spec(spec)
34
- spec.loader.exec_module(mod) # executes the .py file
35
 
36
  if not hasattr(mod, "MANIFEST"):
37
  raise SystemExit(f"Python manifest {path} must define a top-level variable MANIFEST")
 
7
  from config import DEFAULT_ALPHA
8
  from models import get_model_config
9
 
 
10
  MODEL_DEFAULT_LAYER = {
11
  "raw": None,
12
  "wavlm": 24,
 
30
  if spec is None or spec.loader is None:
31
  raise SystemExit(f"Could not load Python manifest: {path}")
32
  mod = importlib.util.module_from_spec(spec)
33
+ spec.loader.exec_module(mod)
34
 
35
  if not hasattr(mod, "MANIFEST"):
36
  raise SystemExit(f"Python manifest {path} must define a top-level variable MANIFEST")
audio.py CHANGED
@@ -1,15 +1,21 @@
1
- import librosa
2
  import numpy as np
3
  import pyloudnorm as pyln
4
  import torch
5
 
6
  from config import SILENCE_RATIO, SR
7
- from utils import hungarian, safe_corr_np
8
  import warnings
 
9
  warnings.filterwarnings("ignore", message="Possible clipped samples in output.")
10
 
11
 
12
  def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
 
 
 
 
 
 
 
13
  meter = pyln.Meter(sr)
14
  loudness = meter.integrated_loudness(wav)
15
  normalized_wav = pyln.normalize.loudness(wav, loudness, target_lufs)
@@ -20,42 +26,53 @@ def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
20
 
21
 
22
  def frame_rms_torch(sig, win, hop):
 
 
 
 
 
 
 
23
  dev = sig.device
24
  frames = sig.unfold(0, win, hop)
25
  if frames.size(0) and (frames.size(0) - 1) * hop == sig.numel() - win:
26
  frames = frames[:-1]
27
- rms = torch.sqrt((frames**2).mean(1) + 1e-12)
28
  return rms.to(dev)
29
 
30
 
31
- def make_union_voiced_mask(refs_tensors, win, hop):
 
 
 
 
 
 
 
 
 
 
32
  device = refs_tensors[0].device
33
- rms_vecs = [frame_rms_torch(r, win, hop) for r in refs_tensors]
34
- lengths = [v.numel() for v in rms_vecs]
 
 
 
 
 
 
 
 
35
  L_max = max(lengths)
36
- silent_union = torch.zeros(L_max, dtype=torch.bool, device=device)
37
- for idx, (rms, L) in enumerate(zip(rms_vecs, lengths)):
38
- thr = SILENCE_RATIO * torch.sqrt((refs_tensors[idx] ** 2).mean())
39
- sil = rms <= thr
40
- silent_union[:L] |= sil
41
- return ~silent_union
42
-
43
-
44
- def assign_outputs_to_refs_by_corr(ref_paths, out_paths):
45
- if not out_paths:
46
- return [None] * len(ref_paths)
47
- refs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in ref_paths]
48
- outs = [loudness_normalize(librosa.load(str(p), sr=SR)[0]) for p in out_paths]
49
- n, m = len(refs), len(outs)
50
- K = max(n, m)
51
- C = np.ones((K, K), dtype=np.float64)
52
- for i in range(n):
53
- for j in range(m):
54
- r = safe_corr_np(refs[i], outs[j])
55
- C[i, j] = 1.0 - (r + 1.0) * 0.5 # lower = better
56
- ri, cj = hungarian(C)
57
- mapping = [None] * n
58
- for i, j in zip(ri, cj):
59
- if i < n and j < m:
60
- mapping[i] = int(j)
61
- return mapping
 
 
1
  import numpy as np
2
  import pyloudnorm as pyln
3
  import torch
4
 
5
  from config import SILENCE_RATIO, SR
 
6
  import warnings
7
+
8
  warnings.filterwarnings("ignore", message="Possible clipped samples in output.")
9
 
10
 
11
  def loudness_normalize(wav, sr=SR, target_lufs=-23.0):
12
+ """
13
+ Apply loudness normalization on an audio signal.
14
+ :param wav: waveform signal to normalize.
15
+ :param sr: sampling rate.
16
+ :param target_lufs: LUFS points to normalize to.
17
+ :return: normalized signal.
18
+ """
19
  meter = pyln.Meter(sr)
20
  loudness = meter.integrated_loudness(wav)
21
  normalized_wav = pyln.normalize.loudness(wav, loudness, target_lufs)
 
26
 
27
 
28
  def frame_rms_torch(sig, win, hop):
29
+ """
30
+ Calculates the RMS of a signal with a moving window.
31
+ :param sig: signal for calculation.
32
+ :param win: analysis window size.
33
+ :param hop: analysis window hop size.
34
+ :return: RMS of signal.
35
+ """
36
  dev = sig.device
37
  frames = sig.unfold(0, win, hop)
38
  if frames.size(0) and (frames.size(0) - 1) * hop == sig.numel() - win:
39
  frames = frames[:-1]
40
+ rms = torch.sqrt((frames ** 2).mean(1) + 1e-12)
41
  return rms.to(dev)
42
 
43
 
44
+ def compute_speaker_activity_masks(refs_tensors, win, hop):
45
+ """
46
+ Computes individual voice activity for each speaker and determines which frames
47
+ have at least 2 active speakers.
48
+ :param refs_tensors: references that compose the mixture.
49
+ :param win: analysis window size.
50
+ :param hop: analysis window hop size.
51
+ :return: (multi_speaker_mask, individual_speaker_masks)
52
+ - multi_speaker_mask: boolean mask of frames where at least 2 speakers are active
53
+ - individual_speaker_masks: list of boolean masks, one per speaker
54
+ """
55
  device = refs_tensors[0].device
56
+ individual_masks = []
57
+ lengths = []
58
+
59
+ for ref in refs_tensors:
60
+ rms = frame_rms_torch(ref, win, hop)
61
+ threshold = SILENCE_RATIO * torch.sqrt((ref ** 2).mean())
62
+ voiced = rms > threshold
63
+ individual_masks.append(voiced)
64
+ lengths.append(voiced.numel())
65
+
66
  L_max = max(lengths)
67
+ padded_masks = []
68
+ for mask, L in zip(individual_masks, lengths):
69
+ if L < L_max:
70
+ padded = torch.cat([mask, torch.zeros(L_max - L, dtype=torch.bool, device=device)])
71
+ else:
72
+ padded = mask
73
+ padded_masks.append(padded)
74
+
75
+ stacked = torch.stack(padded_masks, dim=0)
76
+ active_count = stacked.sum(dim=0)
77
+ multi_speaker_mask = active_count >= 2
78
+ return multi_speaker_mask, padded_masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py CHANGED
@@ -1,35 +1,35 @@
1
- import os
2
- import torch
3
-
4
- import warnings
5
- warnings.filterwarnings(
6
- "ignore",
7
- category=UserWarning,
8
- message=r"^expandable_segments not supported on this platform"
9
- )
10
-
11
- SR = 16_000
12
- RESULTS_ROOT = "results"
13
- BATCH_SIZE = 2
14
- ENERGY_WIN_MS = 20
15
- ENERGY_HOP_MS = 20
16
- SILENCE_RATIO = 0.1
17
- EPS = 1e-4
18
- COV_TOL = 1e-6
19
-
20
- DEFAULT_LAYER = 2
21
- DEFAULT_ADD_CI = True
22
- DEFAULT_DELTA_CI = 0.05
23
- DEFAULT_ALPHA = 1.0
24
-
25
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
26
- os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
27
-
28
- torch.backends.cudnn.benchmark = True
29
- torch.backends.cudnn.deterministic = False
30
- torch.backends.cudnn.enabled = True
31
-
32
- # Only set CUDA memory fraction if we're not in the main process on HF Spaces
33
- if not (os.environ.get("SPACE_ID") and torch.cuda.is_available()):
34
- if torch.cuda.is_available():
35
- torch.cuda.set_per_process_memory_fraction(0.8)
 
1
+ """
2
+ Basic configuration and default values used in the MAPSS computations.
3
+ """
4
+ import os
5
+ import torch
6
+ import warnings
7
+ warnings.filterwarnings(
8
+ "ignore",
9
+ category=UserWarning,
10
+ message=r"^expandable_segments not supported on this platform"
11
+ )
12
+
13
+ SR = 16_000
14
+ RESULTS_ROOT = "results"
15
+ BATCH_SIZE = 2
16
+ ENERGY_WIN_MS = 20
17
+ ENERGY_HOP_MS = 20
18
+ SILENCE_RATIO = 0.1
19
+ EPS = 1e-4
20
+ COV_TOL = 1e-6
21
+
22
+ DEFAULT_LAYER = 2
23
+ DEFAULT_ADD_CI = True
24
+ DEFAULT_DELTA_CI = 0.05
25
+ DEFAULT_ALPHA = 1.0
26
+
27
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
28
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
29
+
30
+ torch.backends.cudnn.benchmark = True
31
+ torch.backends.cudnn.deterministic = False
32
+ torch.backends.cudnn.enabled = True
33
+
34
+ if torch.cuda.is_available():
35
+ torch.cuda.set_per_process_memory_fraction(0.8)
distortions.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import librosa
2
  import numpy as np
3
  from numpy.fft import irfft, rfft, rfftfreq
@@ -156,7 +160,7 @@ def frame_distortions(
156
  return distortions
157
 
158
 
159
- def apply_adv_distortions(ref, distortion_keys, sr=SR):
160
  frame_len = int(ENERGY_WIN_MS * sr / 1000)
161
  n_frames = int(np.ceil(len(ref) / frame_len))
162
  pad_len = n_frames * frame_len - len(ref)
@@ -222,7 +226,7 @@ def apply_adv_distortions(ref, distortion_keys, sr=SR):
222
  return list(out.values())
223
 
224
 
225
- def apply_distortions(ref, distortion_keys, sr=SR):
226
  distortions = {}
227
  X = rfft(ref)
228
  freqs = rfftfreq(len(ref), 1 / sr)
 
1
+ """
2
+ Distortions banks for the PS and the PM computations.
3
+ """
4
+
5
  import librosa
6
  import numpy as np
7
  from numpy.fft import irfft, rfft, rfftfreq
 
160
  return distortions
161
 
162
 
163
+ def apply_pm_distortions(ref, distortion_keys, sr=SR):
164
  frame_len = int(ENERGY_WIN_MS * sr / 1000)
165
  n_frames = int(np.ceil(len(ref) / frame_len))
166
  pad_len = n_frames * frame_len - len(ref)
 
226
  return list(out.values())
227
 
228
 
229
+ def apply_ps_distortions(ref, distortion_keys, sr=SR):
230
  distortions = {}
231
  X = rfft(ref)
232
  freqs = rfftfreq(len(ref), 1 / sr)
engine.py CHANGED
@@ -4,14 +4,12 @@ from concurrent.futures import ThreadPoolExecutor
4
  from datetime import datetime
5
  import librosa
6
  import pandas as pd
7
- import numpy as np
8
  from audio import (
9
- assign_outputs_to_refs_by_corr,
10
  loudness_normalize,
11
- make_union_voiced_mask,
12
  )
13
  from config import *
14
- from distortions import apply_adv_distortions, apply_distortions
15
  from metrics import (
16
  compute_pm,
17
  compute_ps,
@@ -38,6 +36,23 @@ def compute_mapss_measures(
38
  verbose=False,
39
  max_gpus=None,
40
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  gpu_distributor = GPUWorkDistributor(max_gpus)
42
  ngpu = get_gpu_count(max_gpus)
43
 
@@ -64,13 +79,17 @@ def compute_mapss_measures(
64
 
65
  for m, mix_entries in zip(canon_mix, mixture_entries):
66
  for algo, out_list in (m.systems or {}).items():
67
- mapping = assign_outputs_to_refs_by_corr(
68
- [e["ref"] for e in mix_entries], out_list
69
- )
 
 
 
 
 
 
70
  for idx, e in enumerate(mix_entries):
71
- j = mapping[idx]
72
- if j is not None:
73
- e["outs"][algo] = out_list[j]
74
 
75
  if algos is None:
76
  algos_to_run = sorted(
@@ -114,6 +133,7 @@ def compute_mapss_measures(
114
 
115
  print(f"Starting experiment {exp_id} with {ngpu} GPUs")
116
  print(f"Results will be saved to: {exp_root}")
 
117
 
118
  clear_gpu_memory()
119
  get_gpu_memory_info(verbose)
@@ -128,63 +148,53 @@ def compute_mapss_measures(
128
  all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
129
 
130
  if verbose:
131
- print("Computing voiced masks...")
132
 
133
  win = int(ENERGY_WIN_MS * SR / 1000)
134
  hop = int(ENERGY_HOP_MS * SR / 1000)
135
- voiced_mask_mix = []
136
- total_frames_per_mix = [] # Store total frames for each mixture
 
137
 
138
  for i, mix in enumerate(mixture_entries):
139
  if verbose:
140
- print(f" Computing mask for mixture {i + 1}/{len(mixture_entries)}")
141
 
142
  if ngpu > 0:
143
  with torch.cuda.device(0):
144
  refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
145
- mask = make_union_voiced_mask(refs_for_mix, win, hop)
146
- voiced_mask_mix.append(mask.cpu())
147
- total_frames_per_mix.append(mask.shape[0])
148
- # Explicitly delete GPU tensors
149
  for ref in refs_for_mix:
150
  del ref
151
  torch.cuda.empty_cache()
152
  else:
153
  refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
154
- mask = make_union_voiced_mask(refs_for_mix, win, hop)
155
- voiced_mask_mix.append(mask.cpu())
156
- total_frames_per_mix.append(mask.shape[0])
 
157
 
158
  ordered_speakers = [e["id"] for e in flat_entries]
159
-
160
- # Initialize storage for all mixtures and algorithms
161
- all_mixture_results = {} # mixture_id -> {algo -> {model -> data}}
162
-
163
  for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
164
  mixture_id = mix_canon.mixture_id
165
  all_mixture_results[mixture_id] = {}
166
-
167
- # Get total frames for this mixture
168
  total_frames = total_frames_per_mix[mix_idx]
169
-
170
- # Get speakers for this mixture
171
  mixture_speakers = [e["id"] for e in mix_entries]
172
 
173
  for algo_idx, algo in enumerate(algos_to_run):
174
  if verbose:
175
  print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
176
-
177
- # Remove the old algo_dir creation here - we don't need these empty folders anymore
178
-
179
  all_outs = {}
180
  missing = []
181
-
182
  for e in mix_entries:
183
  assigned_path = e.get("outs", {}).get(algo)
184
  if assigned_path is None:
185
  missing.append((e["mixture"], e["id"]))
186
  continue
187
-
188
  wav, _ = librosa.load(str(assigned_path), sr=SR)
189
  all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
190
 
@@ -201,11 +211,9 @@ def compute_mapss_measures(
201
  warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
202
  continue
203
 
204
- # Initialize storage for this algorithm
205
  if algo not in all_mixture_results[mixture_id]:
206
  all_mixture_results[mixture_id][algo] = {}
207
 
208
- # Initialize frame-wise storage with NaN for all frames
209
  ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
210
  pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
211
  ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
@@ -224,7 +232,6 @@ def compute_mapss_measures(
224
  model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
225
  get_gpu_memory_info(verbose)
226
 
227
- # Process only this mixture
228
  speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
229
  if not speakers_this_mix:
230
  continue
@@ -232,22 +239,25 @@ def compute_mapss_measures(
232
  if verbose:
233
  print(f" Processing {metric_type} for mixture {mixture_id}")
234
 
235
- all_signals_mix = []
236
- all_masks_mix = []
237
- all_labels_mix = []
238
 
239
- for e in speakers_this_mix:
 
 
 
240
  s = e["id"]
241
 
242
  if metric_type == "PS":
243
  dists = [
244
  loudness_normalize(d)
245
- for d in apply_distortions(all_refs[s].numpy(), "all")
246
  ]
247
  else:
248
  dists = [
249
  loudness_normalize(d)
250
- for d in apply_adv_distortions(
251
  all_refs[s].numpy(), "all"
252
  )
253
  ]
@@ -255,19 +265,20 @@ def compute_mapss_measures(
255
  sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
256
  lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
257
 
258
- masks = [voiced_mask_mix[mix_idx]] * len(sigs)
259
- all_signals_mix.extend(sigs)
260
- all_masks_mix.extend(masks)
261
- all_labels_mix.extend([f"{s}-{l}" for l in lbls])
 
 
 
262
 
263
- try:
264
- # Process in smaller batches
265
  batch_size = min(2, BATCH_SIZE)
266
  embeddings_list = []
267
 
268
- for i in range(0, len(all_signals_mix), batch_size):
269
- batch_sigs = all_signals_mix[i:i + batch_size]
270
- batch_masks = all_masks_mix[i:i + batch_size]
271
 
272
  batch_embs = embed_batch(
273
  batch_sigs,
@@ -283,139 +294,156 @@ def compute_mapss_measures(
283
  torch.cuda.empty_cache()
284
 
285
  if embeddings_list:
286
- embeddings = torch.cat(embeddings_list, dim=0)
287
- E, L, D = embeddings.shape
 
288
 
289
- if L == 0:
290
- if verbose:
291
- print(
292
- f" WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
293
- continue
 
294
 
295
- # Get valid frame indices for this mixture
296
- mask = voiced_mask_mix[mix_idx]
297
- valid_frame_indices = torch.where(mask)[0].tolist()
298
-
299
- if verbose:
300
- print(f" Computing {metric_type} scores for {mname}...")
301
-
302
- # Process frames with their stored embeddings and labels
303
- with ThreadPoolExecutor(
304
- max_workers=min(2, ngpu if ngpu > 0 else 1)
305
- ) as executor:
306
-
307
- def process_frame(f, frame_idx, embeddings_mix, labels_mix):
308
- try:
309
- frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
310
-
311
- if add_ci:
312
- coords_d, coords_c, eigvals, k_sub_gauss = (
313
- gpu_distributor.execute_on_gpu(
314
- diffusion_map_torch,
315
- frame_emb,
316
- labels_mix,
317
- alpha=alpha,
318
- eig_solver="full",
319
- return_eigs=True,
320
- return_complement=True,
321
- return_cval=add_ci,
322
- )
323
- )
324
- else:
325
- coords_d = gpu_distributor.execute_on_gpu(
326
- diffusion_map_torch,
327
- frame_emb,
328
- labels_mix,
329
- alpha=alpha,
330
- eig_solver="full",
331
- return_eigs=False,
332
- return_complement=False,
333
- return_cval=False,
334
- )
335
- coords_c = None
336
- eigvals = None
337
- k_sub_gauss = 1
338
-
339
- if metric_type == "PS":
340
- score = compute_ps(
341
- coords_d, labels_mix, max_gpus
342
- )
343
- bias = prob = None
344
- if add_ci:
345
- bias, prob = ps_ci_components_full(
346
- coords_d,
347
- coords_c,
348
- eigvals,
349
- labels_mix,
350
- delta=DEFAULT_DELTA_CI,
351
- )
352
- return frame_idx, "PS", score, bias, prob
353
- else:
354
- score = compute_pm(
355
- coords_d, labels_mix, "gamma", max_gpus
356
- )
357
- bias = prob = None
358
- if add_ci:
359
- bias, prob = pm_ci_components_full(
360
- coords_d,
361
- coords_c,
362
- eigvals,
363
- labels_mix,
364
- delta=DEFAULT_DELTA_CI,
365
- K=k_sub_gauss,
366
- )
367
- return frame_idx, "PM", score, bias, prob
368
-
369
- except Exception as ex:
370
- if verbose:
371
- print(f" ERROR frame {frame_idx}: {ex}")
372
- return None
373
-
374
- futures = [
375
- executor.submit(process_frame, f, valid_frame_indices[f], embeddings,
376
- all_labels_mix)
377
- for f in range(L)
378
- ]
379
-
380
- for fut in futures:
381
- result = fut.result()
382
- if result is None:
383
- continue
384
-
385
- frame_idx, metric, score, bias, prob = result
386
-
387
- if metric == "PS":
388
- for sp in score:
389
- if sp in mixture_speakers:
390
- ps_frames[mname][sp][frame_idx] = score[sp]
391
- if add_ci and bias is not None:
392
- ps_bias_frames[mname][sp][frame_idx] = bias[sp]
393
- ps_prob_frames[mname][sp][frame_idx] = prob[sp]
394
- else:
395
- for sp in score:
396
- if sp in mixture_speakers:
397
- pm_frames[mname][sp][frame_idx] = score[sp]
398
- if add_ci and bias is not None:
399
- pm_bias_frames[mname][sp][frame_idx] = bias[sp]
400
- pm_prob_frames[mname][sp][frame_idx] = prob[sp]
401
-
402
- except Exception as ex:
403
  if verbose:
404
- print(f" ERROR processing mixture {mixture_id}: {ex}")
405
  continue
406
- finally:
407
- # Always clean up after processing a mixture
408
- del all_signals_mix, all_masks_mix
409
- if 'embeddings_list' in locals():
410
- del embeddings_list
411
- clear_gpu_memory()
412
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
  del model_wrapper
415
  clear_gpu_memory()
416
  gc.collect()
417
 
418
- # Store results for this mixture and algorithm
419
  all_mixture_results[mixture_id][algo][mname] = {
420
  'ps_frames': ps_frames[mname],
421
  'pm_frames': pm_frames[mname],
@@ -426,20 +454,16 @@ def compute_mapss_measures(
426
  'total_frames': total_frames
427
  }
428
 
429
- # Save results for this mixture after processing all algorithms
430
  if verbose:
431
- print(f" Saving results for mixture {mixture_id}...")
432
 
433
- # Create timestamps in milliseconds - using lowercase hop
434
  timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
435
 
436
  for model in models:
437
- # Prepare PS data
438
  ps_data = {'timestamp_ms': timestamps_ms}
439
  pm_data = {'timestamp_ms': timestamps_ms}
440
  ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None
441
 
442
- # Combine data from all algorithms for this mixture
443
  for algo in algos_to_run:
444
  if algo not in all_mixture_results[mixture_id]:
445
  continue
@@ -448,7 +472,6 @@ def compute_mapss_measures(
448
 
449
  model_data = all_mixture_results[mixture_id][algo][model]
450
 
451
- # Add PS data
452
  for speaker in mixture_speakers:
453
  col_name = f"{algo}_{speaker}"
454
  ps_data[col_name] = model_data['ps_frames'][speaker]
@@ -460,7 +483,6 @@ def compute_mapss_measures(
460
  ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
461
  ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]
462
 
463
- # Save CSV files for this mixture
464
  mixture_dir = os.path.join(exp_root, mixture_id)
465
  os.makedirs(mixture_dir, exist_ok=True)
466
 
@@ -487,9 +509,8 @@ def compute_mapss_measures(
487
  print(f"\nEXPERIMENT COMPLETED")
488
  print(f"Results saved to: {exp_root}")
489
 
490
- del all_refs, voiced_mask_mix
491
 
492
- # Import and call the cleanup function
493
  from models import cleanup_all_models
494
  cleanup_all_models()
495
 
 
4
  from datetime import datetime
5
  import librosa
6
  import pandas as pd
 
7
  from audio import (
 
8
  loudness_normalize,
9
+ compute_speaker_activity_masks,
10
  )
11
  from config import *
12
+ from distortions import apply_pm_distortions, apply_ps_distortions
13
  from metrics import (
14
  compute_pm,
15
  compute_ps,
 
36
  verbose=False,
37
  max_gpus=None,
38
  ):
39
+ """
40
+ Compute MAPSS measures (PM, PS, and their errors). Data is saved to csv files.
41
+
42
+ :param models: backbone self-supervised models.
43
+ :param mixtures: data to process from _read_manifest
44
+ :param systems: specific systems (algos and data)
45
+ :param algos: specific algorithms to use
46
+ :param experiment_id: user-specified name for experiment
47
+ :param layer: transformer layer of model to consider
48
+ :param add_ci: True will compute error radius and tail bounds. False will not.
49
+ :param alpha: normalization factor of the diffusion maps. Lives in [0, 1].
50
+ :param seed: random seed number.
51
+ :param on_missing: "skip" when missing values or throw an "error".
52
+ :param verbose: True will print process info to console during runtime. False will minimize it.
53
+ :param max_gpus: maximal amount of GPUs the program tries to utilize in parallel.
54
+
55
+ """
56
  gpu_distributor = GPUWorkDistributor(max_gpus)
57
  ngpu = get_gpu_count(max_gpus)
58
 
 
79
 
80
  for m, mix_entries in zip(canon_mix, mixture_entries):
81
  for algo, out_list in (m.systems or {}).items():
82
+ if len(out_list) != len(mix_entries):
83
+ msg = f"[{algo}] Number of outputs ({len(out_list)}) does not match number of references ({len(mix_entries)}) for mixture {m.mixture_id}"
84
+ if on_missing == "error":
85
+ raise ValueError(msg)
86
+ else:
87
+ if verbose:
88
+ warnings.warn(msg + " Skipping this algorithm.")
89
+ continue
90
+
91
  for idx, e in enumerate(mix_entries):
92
+ e["outs"][algo] = out_list[idx]
 
 
93
 
94
  if algos is None:
95
  algos_to_run = sorted(
 
133
 
134
  print(f"Starting experiment {exp_id} with {ngpu} GPUs")
135
  print(f"Results will be saved to: {exp_root}")
136
+ print("NOTE: Output files must be provided in the same order as reference files.")
137
 
138
  clear_gpu_memory()
139
  get_gpu_memory_info(verbose)
 
148
  all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
149
 
150
  if verbose:
151
+ print("Computing speaker activity masks...")
152
 
153
  win = int(ENERGY_WIN_MS * SR / 1000)
154
  hop = int(ENERGY_HOP_MS * SR / 1000)
155
+ multi_speaker_masks_mix = []
156
+ individual_speaker_masks_mix = []
157
+ total_frames_per_mix = []
158
 
159
  for i, mix in enumerate(mixture_entries):
160
  if verbose:
161
+ print(f" Computing masks for mixture {i + 1}/{len(mixture_entries)}")
162
 
163
  if ngpu > 0:
164
  with torch.cuda.device(0):
165
  refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
166
+ multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
167
+ multi_speaker_masks_mix.append(multi_mask.cpu())
168
+ individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
169
+ total_frames_per_mix.append(multi_mask.shape[0])
170
  for ref in refs_for_mix:
171
  del ref
172
  torch.cuda.empty_cache()
173
  else:
174
  refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
175
+ multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
176
+ multi_speaker_masks_mix.append(multi_mask.cpu())
177
+ individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
178
+ total_frames_per_mix.append(multi_mask.shape[0])
179
 
180
  ordered_speakers = [e["id"] for e in flat_entries]
181
+ all_mixture_results = {}
 
 
 
182
  for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
183
  mixture_id = mix_canon.mixture_id
184
  all_mixture_results[mixture_id] = {}
 
 
185
  total_frames = total_frames_per_mix[mix_idx]
 
 
186
  mixture_speakers = [e["id"] for e in mix_entries]
187
 
188
  for algo_idx, algo in enumerate(algos_to_run):
189
  if verbose:
190
  print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
 
 
 
191
  all_outs = {}
192
  missing = []
 
193
  for e in mix_entries:
194
  assigned_path = e.get("outs", {}).get(algo)
195
  if assigned_path is None:
196
  missing.append((e["mixture"], e["id"]))
197
  continue
 
198
  wav, _ = librosa.load(str(assigned_path), sr=SR)
199
  all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
200
 
 
211
  warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
212
  continue
213
 
 
214
  if algo not in all_mixture_results[mixture_id]:
215
  all_mixture_results[mixture_id][algo] = {}
216
 
 
217
  ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
218
  pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
219
  ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
 
232
  model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
233
  get_gpu_memory_info(verbose)
234
 
 
235
  speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
236
  if not speakers_this_mix:
237
  continue
 
239
  if verbose:
240
  print(f" Processing {metric_type} for mixture {mixture_id}")
241
 
242
+ multi_speaker_mask = multi_speaker_masks_mix[mix_idx]
243
+ individual_masks = individual_speaker_masks_mix[mix_idx]
244
+ valid_frame_indices = torch.where(multi_speaker_mask)[0].tolist()
245
 
246
+ speaker_signals = {}
247
+ speaker_labels = {}
248
+
249
+ for speaker_idx, e in enumerate(speakers_this_mix):
250
  s = e["id"]
251
 
252
  if metric_type == "PS":
253
  dists = [
254
  loudness_normalize(d)
255
+ for d in apply_ps_distortions(all_refs[s].numpy(), "all")
256
  ]
257
  else:
258
  dists = [
259
  loudness_normalize(d)
260
+ for d in apply_pm_distortions(
261
  all_refs[s].numpy(), "all"
262
  )
263
  ]
 
265
  sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
266
  lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
267
 
268
+ speaker_signals[s] = sigs
269
+ speaker_labels[s] = [f"{s}-{l}" for l in lbls]
270
+
271
+ all_embeddings = {}
272
+ for s in speaker_signals:
273
+ sigs = speaker_signals[s]
274
+ masks = [multi_speaker_mask] * len(sigs)
275
 
 
 
276
  batch_size = min(2, BATCH_SIZE)
277
  embeddings_list = []
278
 
279
+ for i in range(0, len(sigs), batch_size):
280
+ batch_sigs = sigs[i:i + batch_size]
281
+ batch_masks = masks[i:i + batch_size]
282
 
283
  batch_embs = embed_batch(
284
  batch_sigs,
 
294
  torch.cuda.empty_cache()
295
 
296
  if embeddings_list:
297
+ all_embeddings[s] = torch.cat(embeddings_list, dim=0)
298
+ else:
299
+ all_embeddings[s] = torch.empty(0, 0, 0)
300
 
301
+ if not all_embeddings or all(e.numel() == 0 for e in all_embeddings.values()):
302
+ if verbose:
303
+ print(f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
304
+ continue
305
+
306
+ L = next(iter(all_embeddings.values())).shape[1] if all_embeddings else 0
307
 
308
+ if L == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  if verbose:
310
+ print(f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
311
  continue
312
+
313
+ if verbose:
314
+ print(f"Computing {metric_type} scores for {mname}...")
315
+
316
+ with ThreadPoolExecutor(
317
+ max_workers=min(2, ngpu if ngpu > 0 else 1)
318
+ ) as executor:
319
+
320
+ def process_frame(f, frame_idx, all_embeddings_dict, speaker_labels_dict, individual_masks_list,
321
+ speaker_indices):
322
+ try:
323
+ active_speakers = []
324
+ for spk_idx, spk_id in enumerate(speaker_indices):
325
+ if individual_masks_list[spk_idx][frame_idx]:
326
+ active_speakers.append(spk_id)
327
+
328
+ if len(active_speakers) < 2:
329
+ return frame_idx, metric_type, {}, None, None
330
+
331
+ frame_embeddings = []
332
+ frame_labels = []
333
+ for spk_id in active_speakers:
334
+ spk_embs = all_embeddings_dict[spk_id][:, f, :]
335
+ frame_embeddings.append(spk_embs)
336
+ frame_labels.extend(speaker_labels_dict[spk_id])
337
+
338
+ frame_emb = torch.cat(frame_embeddings, dim=0).detach().cpu().numpy()
339
+
340
+ if add_ci:
341
+ coords_d, coords_c, eigvals, k_sub_gauss = (
342
+ gpu_distributor.execute_on_gpu(
343
+ diffusion_map_torch,
344
+ frame_emb,
345
+ frame_labels,
346
+ alpha=alpha,
347
+ eig_solver="full",
348
+ return_eigs=True,
349
+ return_complement=True,
350
+ return_cval=add_ci,
351
+ )
352
+ )
353
+ else:
354
+ coords_d = gpu_distributor.execute_on_gpu(
355
+ diffusion_map_torch,
356
+ frame_emb,
357
+ frame_labels,
358
+ alpha=alpha,
359
+ eig_solver="full",
360
+ return_eigs=False,
361
+ return_complement=False,
362
+ return_cval=False,
363
+ )
364
+ coords_c = None
365
+ eigvals = None
366
+ k_sub_gauss = 1
367
+
368
+ if metric_type == "PS":
369
+ score = compute_ps(
370
+ coords_d, frame_labels, max_gpus
371
+ )
372
+ bias = prob = None
373
+ if add_ci:
374
+ bias, prob = ps_ci_components_full(
375
+ coords_d,
376
+ coords_c,
377
+ eigvals,
378
+ frame_labels,
379
+ delta=DEFAULT_DELTA_CI,
380
+ )
381
+ return frame_idx, "PS", score, bias, prob
382
+ else:
383
+ score = compute_pm(
384
+ coords_d, frame_labels, "gamma", max_gpus
385
+ )
386
+ bias = prob = None
387
+ if add_ci:
388
+ bias, prob = pm_ci_components_full(
389
+ coords_d,
390
+ coords_c,
391
+ eigvals,
392
+ frame_labels,
393
+ delta=DEFAULT_DELTA_CI,
394
+ K=k_sub_gauss,
395
+ )
396
+ return frame_idx, "PM", score, bias, prob
397
+
398
+ except Exception as ex:
399
+ if verbose:
400
+ print(f"ERROR frame {frame_idx}: {ex}")
401
+ return None
402
+
403
+ speaker_ids = [e["id"] for e in speakers_this_mix]
404
+
405
+ futures = [
406
+ executor.submit(
407
+ process_frame,
408
+ f,
409
+ valid_frame_indices[f],
410
+ all_embeddings,
411
+ speaker_labels,
412
+ individual_masks,
413
+ speaker_ids
414
+ )
415
+ for f in range(L)
416
+ ]
417
+
418
+ for fut in futures:
419
+ result = fut.result()
420
+ if result is None:
421
+ continue
422
+
423
+ frame_idx, metric, score, bias, prob = result
424
+
425
+ if metric == "PS":
426
+ for sp in mixture_speakers:
427
+ if sp in score:
428
+ ps_frames[mname][sp][frame_idx] = score[sp]
429
+ if add_ci and bias is not None and sp in bias:
430
+ ps_bias_frames[mname][sp][frame_idx] = bias[sp]
431
+ ps_prob_frames[mname][sp][frame_idx] = prob[sp]
432
+ else:
433
+ for sp in mixture_speakers:
434
+ if sp in score:
435
+ pm_frames[mname][sp][frame_idx] = score[sp]
436
+ if add_ci and bias is not None and sp in bias:
437
+ pm_bias_frames[mname][sp][frame_idx] = bias[sp]
438
+ pm_prob_frames[mname][sp][frame_idx] = prob[sp]
439
+
440
+ clear_gpu_memory()
441
+ gc.collect()
442
 
443
  del model_wrapper
444
  clear_gpu_memory()
445
  gc.collect()
446
 
 
447
  all_mixture_results[mixture_id][algo][mname] = {
448
  'ps_frames': ps_frames[mname],
449
  'pm_frames': pm_frames[mname],
 
454
  'total_frames': total_frames
455
  }
456
 
 
457
  if verbose:
458
+ print(f"Saving results for mixture {mixture_id}...")
459
 
 
460
  timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
461
 
462
  for model in models:
 
463
  ps_data = {'timestamp_ms': timestamps_ms}
464
  pm_data = {'timestamp_ms': timestamps_ms}
465
  ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None
466
 
 
467
  for algo in algos_to_run:
468
  if algo not in all_mixture_results[mixture_id]:
469
  continue
 
472
 
473
  model_data = all_mixture_results[mixture_id][algo][model]
474
 
 
475
  for speaker in mixture_speakers:
476
  col_name = f"{algo}_{speaker}"
477
  ps_data[col_name] = model_data['ps_frames'][speaker]
 
483
  ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
484
  ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]
485
 
 
486
  mixture_dir = os.path.join(exp_root, mixture_id)
487
  os.makedirs(mixture_dir, exist_ok=True)
488
 
 
509
  print(f"\nEXPERIMENT COMPLETED")
510
  print(f"Results saved to: {exp_root}")
511
 
512
+ del all_refs, multi_speaker_masks_mix, individual_speaker_masks_mix
513
 
 
514
  from models import cleanup_all_models
515
  cleanup_all_models()
516
 
main.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
  from pathlib import Path
3
  from engine import compute_mapss_measures
 
1
+ """
2
+ Entry point from the CLI into the MAPSS calculation.
3
+
4
+ IMPORTANT: Output files must be provided in the same order as reference files.
5
+ For example, if references are ["ref1.wav", "ref2.wav"],
6
+ then outputs must be ["out1.wav", "out2.wav"] in that exact order.
7
+ """
8
+
9
  from __future__ import annotations
10
  from pathlib import Path
11
  from engine import compute_mapss_measures
metrics.py CHANGED
@@ -1,17 +1,20 @@
1
-
2
  import math
3
-
4
  import numpy as np
5
  import torch
6
  from scipy.special import gammaincc
7
  from scipy.stats import gamma
8
 
9
- from config import COV_TOL, DEFAULT_DELTA_CI
10
  from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch
11
 
12
 
13
  def pm_tail_gamma(d_out_sq, sq_dists):
14
- """PM tail gamma exactly as original."""
 
 
 
 
 
15
  mu = sq_dists.mean().item()
16
  var = sq_dists.var(unbiased=True).item()
17
  if var == 0.0:
@@ -22,7 +25,9 @@ def pm_tail_gamma(d_out_sq, sq_dists):
22
 
23
 
24
  def pm_tail_rank(d_out_sq, sq_dists):
25
- """PM tail rank exactly as original."""
 
 
26
  rank = int((sq_dists < d_out_sq).sum().item())
27
  n = sq_dists.numel()
28
  return 1.0 - (rank + 0.5) / (n + 1.0)
@@ -43,7 +48,23 @@ def diffusion_map_torch(
43
  return_complement=False,
44
  return_cval=False,
45
  ):
46
- """Diffusion map computation exactly as original."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
48
  X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
49
  N = X.shape[0]
@@ -141,6 +162,13 @@ def diffusion_map_torch(
141
 
142
 
143
  def compute_ps(coords, labels, max_gpus=None):
 
 
 
 
 
 
 
144
  ngpu = get_gpu_count(max_gpus)
145
 
146
  if ngpu == 0:
@@ -171,8 +199,7 @@ def compute_ps(coords, labels, max_gpus=None):
171
  out[s] = (1 - A / (A + B_min + 1e-6)).item()
172
  return out
173
 
174
- # GPU version
175
- device = min(ngpu - 1, 1) # Use second GPU if available
176
  device_str = f"cuda:{device}"
177
  coords_t = torch.tensor(coords, device=device_str)
178
  spks_here = sorted({l.split("-")[0] for l in labels})
@@ -212,6 +239,14 @@ def compute_ps(coords, labels, max_gpus=None):
212
 
213
 
214
  def compute_pm(coords, labels, pm_method, max_gpus=None):
 
 
 
 
 
 
 
 
215
  ngpu = get_gpu_count(max_gpus)
216
 
217
  if ngpu == 0:
@@ -245,7 +280,6 @@ def compute_pm(coords, labels, pm_method, max_gpus=None):
245
  out[s] = float(np.clip(pm_score, 0.0, 1.0))
246
  return out
247
 
248
- # GPU version
249
  device = min(ngpu - 1, 1)
250
  device_str = f"cuda:{device}"
251
  coords_t = torch.tensor(coords, device=device_str)
@@ -287,7 +321,18 @@ def compute_pm(coords, labels, pm_method, max_gpus=None):
287
  def pm_ci_components_full(
288
  coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
289
  ):
290
- """PM CI components exactly as original - complete implementation."""
 
 
 
 
 
 
 
 
 
 
 
291
  _EPS = 1e-12
292
 
293
  def _safe_x(a, theta):
@@ -387,7 +432,6 @@ def pm_ci_components_full(
387
 
388
  bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)
389
 
390
- # Probabilistic half-width
391
  R_sq = float(mah_sq.max()) + 1e-12
392
  log_term = math.log(6.0 / delta)
393
  eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
@@ -419,7 +463,15 @@ def pm_ci_components_full(
419
 
420
 
421
  def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
422
- """PS CI components exactly as original - complete implementation."""
 
 
 
 
 
 
 
 
423
 
424
  def _mean_dev(lam_max, delta, n_eff):
425
  return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)
 
 
1
  import math
 
2
  import numpy as np
3
  import torch
4
  from scipy.special import gammaincc
5
  from scipy.stats import gamma
6
 
7
+ from config import COV_TOL
8
  from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch
9
 
10
 
11
  def pm_tail_gamma(d_out_sq, sq_dists):
12
+ """
13
+ Computes the PM measure based on the Gamma fit.
14
+ :param d_out_sq: squared mahalanobis distance from the output to its cluster on the manifold.
15
+ :param sq_dists: squared mahalanobis distance of all distortions in the cluster to their cluster on the manifold.
16
+ :return: PM score.
17
+ """
18
  mu = sq_dists.mean().item()
19
  var = sq_dists.var(unbiased=True).item()
20
  if var == 0.0:
 
25
 
26
 
27
  def pm_tail_rank(d_out_sq, sq_dists):
28
+ """
29
+ A depracted method to compute the PM measure based on the ranking method of distances.
30
+ """
31
  rank = int((sq_dists < d_out_sq).sum().item())
32
  n = sq_dists.numel()
33
  return 1.0 - (rank + 0.5) / (n + 1.0)
 
48
  return_complement=False,
49
  return_cval=False,
50
  ):
51
+ """
52
+ Compute diffusion maps from a high dimensional set of points.
53
+
54
+ :param X_np: high dimensional input.
55
+ :param labels_by_mix: used to keep track of each source's coordinates on the manifold.
56
+ :param cutoff: the desired ratio between sum of kept and sum of all eigenvalues.
57
+ :param tol: deprecated since we do not use the "lobpcg" solver.
58
+ :param diffusion_time: number of steps taken on the probability transition matrix.
59
+ :param alpha: normalization factor in [0, 1].
60
+ :param eig_solver: "lobpcg" or "full".
61
+ :param k: pre-defined truncation dimension.
62
+ :param device: "cpu" or "cuda".
63
+ :param return_eigs: return eigenvalues and eigenvectors.
64
+ :param return_complement: return complementary coordinates, not just kept coordinates.
65
+ :param return_cval: calculate and return the psi_2 norm of the coordinates.
66
+ :return:
67
+ """
68
  device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
69
  X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
70
  N = X.shape[0]
 
162
 
163
 
164
  def compute_ps(coords, labels, max_gpus=None):
165
+ """
166
+ Computes the PS measure.
167
+ :param coords: coordinates on the manifold.
168
+ :param labels: assign source index per coordinate.
169
+ :param max_gpus: maximal number of GPUs to use.
170
+ :return: the PS measure.
171
+ """
172
  ngpu = get_gpu_count(max_gpus)
173
 
174
  if ngpu == 0:
 
199
  out[s] = (1 - A / (A + B_min + 1e-6)).item()
200
  return out
201
 
202
+ device = min(ngpu - 1, 1)
 
203
  device_str = f"cuda:{device}"
204
  coords_t = torch.tensor(coords, device=device_str)
205
  spks_here = sorted({l.split("-")[0] for l in labels})
 
239
 
240
 
241
  def compute_pm(coords, labels, pm_method, max_gpus=None):
242
+ """
243
+ Computes the PM measure.
244
+ :param coords: coordinates on the manifold.
245
+ :param labels: assign source index per coordinate.
246
+ :param pm_method: "rank" or "gamma".
247
+ :param max_gpus: maximal number of GPUs to use.
248
+ :return: the PS measure.
249
+ """
250
  ngpu = get_gpu_count(max_gpus)
251
 
252
  if ngpu == 0:
 
280
  out[s] = float(np.clip(pm_score, 0.0, 1.0))
281
  return out
282
 
 
283
  device = min(ngpu - 1, 1)
284
  device_str = f"cuda:{device}"
285
  coords_t = torch.tensor(coords, device=device_str)
 
321
  def pm_ci_components_full(
322
  coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
323
  ):
324
+ """
325
+ Computes the error radius and tail bounds for the PM measure.
326
+ :param coords_d: Retained diffusion maps coordinates.
327
+ :param coords_rest: Complement diffusion maps coordinates.
328
+ :param eigvals: Eigenvalues of the diffusion maps.
329
+ :param labels: Assign source index per coordinate
330
+ :param delta: 1-\delta is the confidence score.
331
+ :param K: Absolute constant.
332
+ :param C1: Absolute constant.
333
+ :param C2: Absolute constant.
334
+ :return: error radius and tail bounds for the PM measure.
335
+ """
336
  _EPS = 1e-12
337
 
338
  def _safe_x(a, theta):
 
432
 
433
  bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)
434
 
 
435
  R_sq = float(mah_sq.max()) + 1e-12
436
  log_term = math.log(6.0 / delta)
437
  eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
 
463
 
464
 
465
  def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
466
+ """
467
+ Computes the error radius and tail bounds for the PS measure.
468
+ :param coords_d: Retained diffusion maps coordinates.
469
+ :param coords_rest: Complement diffusion maps coordinates.
470
+ :param eigvals: Eigenvalues of the diffusion maps.
471
+ :param labels: Assign source index per coordinate
472
+ :param delta: 1-\delta is the confidence score.
473
+ :return: error radius and tail bounds for the PS measure.
474
+ """
475
 
476
  def _mean_dev(lam_max, delta, n_eff):
477
  return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)
models.py CHANGED
@@ -15,8 +15,10 @@ from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
15
  from utils import get_gpu_count
16
 
17
 
18
- class BalancedDualGPUModel:
19
-
 
 
20
  def __init__(self, model_name, layer, max_gpus=None):
21
  self.layer = layer
22
  self.models = []
@@ -24,7 +26,7 @@ class BalancedDualGPUModel:
24
  self.devices = []
25
  ngpu = get_gpu_count(max_gpus)
26
 
27
- for gpu_id in range(min(ngpu, 2)):
28
  device = f"cuda:{gpu_id}"
29
  self.devices.append(device)
30
  ckpt, cls, _ = get_model_config(layer)[model_name]
@@ -90,7 +92,6 @@ class BalancedDualGPUModel:
90
  mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
91
  keep.append(hs[b][mask_t].cpu())
92
 
93
- # Aggressive cleanup
94
  del hs, input_values, inputs
95
  torch.cuda.empty_cache()
96
 
@@ -106,7 +107,6 @@ class BalancedDualGPUModel:
106
  except Exception as e:
107
  self.result_queue.put((task_id, e))
108
  finally:
109
- # Always clear cache after processing
110
  torch.cuda.empty_cache()
111
 
112
  def process_batch(self, signals, masks, use_mlm=False):
@@ -150,8 +150,12 @@ class BalancedDualGPUModel:
150
  self.cleanup()
151
 
152
 
153
- # NO CACHE - we need to clean up models properly between runs
154
  def get_model_config(layer):
 
 
 
 
 
155
  return {
156
  "raw": (None, None, None),
157
  "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
@@ -164,21 +168,25 @@ def get_model_config(layer):
164
  }
165
 
166
 
167
- # Store loaded models globally to properly manage them
168
  _loaded_models = {}
169
 
170
 
171
  def load_model(name, layer, max_gpus=None):
 
 
 
 
 
 
 
172
  global _loaded_models
173
 
174
- # Clean up any previously loaded models first
175
  if _loaded_models:
176
  for key, model_data in _loaded_models.items():
177
  if isinstance(model_data, tuple) and len(model_data) == 2:
178
- if isinstance(model_data[0], BalancedDualGPUModel):
179
  model_data[0].cleanup()
180
  elif isinstance(model_data[0], tuple):
181
- # Single GPU model
182
  _, model = model_data[0]
183
  del model
184
  _loaded_models.clear()
@@ -190,7 +198,7 @@ def load_model(name, layer, max_gpus=None):
190
 
191
  ngpu = get_gpu_count(max_gpus)
192
  if ngpu > 1:
193
- model = BalancedDualGPUModel(name, layer, max_gpus)
194
  _loaded_models[name] = (model, layer)
195
  return model, layer
196
  else:
@@ -219,15 +227,16 @@ def load_model(name, layer, max_gpus=None):
219
 
220
 
221
  def cleanup_all_models():
222
- """Call this at the end of each experiment to ensure complete cleanup"""
 
 
223
  global _loaded_models
224
  if _loaded_models:
225
  for key, model_data in _loaded_models.items():
226
  if isinstance(model_data, tuple) and len(model_data) == 2:
227
- if isinstance(model_data[0], BalancedDualGPUModel):
228
  model_data[0].cleanup()
229
  elif isinstance(model_data[0], tuple):
230
- # Single GPU model
231
  _, model = model_data[0]
232
  del model
233
  _loaded_models.clear()
@@ -236,6 +245,12 @@ def cleanup_all_models():
236
 
237
 
238
  def embed_batch_raw(signals, masks_audio):
 
 
 
 
 
 
239
  win = int(ENERGY_WIN_MS * SR / 1000)
240
  hop = int(ENERGY_HOP_MS * SR / 1000)
241
  reps, L_max = [], 0
@@ -253,6 +268,9 @@ def embed_batch_raw(signals, masks_audio):
253
  def embed_batch_single_gpu(
254
  signals, masks_audio, extractor, model, layer, use_mlm=False
255
  ):
 
 
 
256
  if not signals:
257
  return torch.empty(0, 0, 0)
258
  device = next(model.parameters()).device
@@ -281,7 +299,6 @@ def embed_batch_single_gpu(
281
  mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
282
  all_keeps.append(hs[b][mask_t].cpu())
283
 
284
- # Aggressive cleanup
285
  del hs, input_values, inputs
286
  torch.cuda.empty_cache()
287
 
@@ -289,7 +306,6 @@ def embed_batch_single_gpu(
289
  L_max = max(x.shape[0] for x in all_keeps)
290
  keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
291
  result = torch.stack(keep_padded, dim=0)
292
- # Clean up intermediate lists
293
  del all_keeps, keep_padded
294
  return result
295
  else:
@@ -297,9 +313,19 @@ def embed_batch_single_gpu(
297
 
298
 
299
  def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
 
 
 
 
 
 
 
 
 
 
300
  if model_wrapper == "raw":
301
  return embed_batch_raw(signals, masks_audio)
302
- if isinstance(model_wrapper, BalancedDualGPUModel):
303
  all_embeddings = []
304
  batch_size = min(BATCH_SIZE, 2)
305
  for i in range(0, len(signals), batch_size):
@@ -308,7 +334,6 @@ def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
308
  )
309
  if batch_emb.numel() > 0:
310
  all_embeddings.append(batch_emb)
311
- # Clear cache after each batch
312
  torch.cuda.empty_cache()
313
 
314
  if all_embeddings:
 
15
  from utils import get_gpu_count
16
 
17
 
18
+ class BalancedMultiGPUModel:
19
+ """
20
+ Distributes model inference workload across GPUs.
21
+ """
22
  def __init__(self, model_name, layer, max_gpus=None):
23
  self.layer = layer
24
  self.models = []
 
26
  self.devices = []
27
  ngpu = get_gpu_count(max_gpus)
28
 
29
+ for gpu_id in range(ngpu):
30
  device = f"cuda:{gpu_id}"
31
  self.devices.append(device)
32
  ckpt, cls, _ = get_model_config(layer)[model_name]
 
92
  mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
93
  keep.append(hs[b][mask_t].cpu())
94
 
 
95
  del hs, input_values, inputs
96
  torch.cuda.empty_cache()
97
 
 
107
  except Exception as e:
108
  self.result_queue.put((task_id, e))
109
  finally:
 
110
  torch.cuda.empty_cache()
111
 
112
  def process_batch(self, signals, masks, use_mlm=False):
 
150
  self.cleanup()
151
 
152
 
 
153
  def get_model_config(layer):
154
+ """
155
+ Get self-supervised model configuration.
156
+ :param layer: specific transformer layer to choose.
157
+ :return: Configuration.
158
+ """
159
  return {
160
  "raw": (None, None, None),
161
  "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
 
168
  }
169
 
170
 
 
171
  _loaded_models = {}
172
 
173
 
174
  def load_model(name, layer, max_gpus=None):
175
+ """
176
+ Load the chosen self-supervised model.
177
+ :param name: name of model.
178
+ :param layer: chosen layer.
179
+ :param max_gpus: maximal gpus to use.
180
+ :return: extractor, model, and layer.
181
+ """
182
  global _loaded_models
183
 
 
184
  if _loaded_models:
185
  for key, model_data in _loaded_models.items():
186
  if isinstance(model_data, tuple) and len(model_data) == 2:
187
+ if isinstance(model_data[0], BalancedMultiGPUModel):
188
  model_data[0].cleanup()
189
  elif isinstance(model_data[0], tuple):
 
190
  _, model = model_data[0]
191
  del model
192
  _loaded_models.clear()
 
198
 
199
  ngpu = get_gpu_count(max_gpus)
200
  if ngpu > 1:
201
+ model = BalancedMultiGPUModel(name, layer, max_gpus)
202
  _loaded_models[name] = (model, layer)
203
  return model, layer
204
  else:
 
227
 
228
 
229
  def cleanup_all_models():
230
+ """
231
+ Call this at the end of each experiment to ensure complete cleanup
232
+ """
233
  global _loaded_models
234
  if _loaded_models:
235
  for key, model_data in _loaded_models.items():
236
  if isinstance(model_data, tuple) and len(model_data) == 2:
237
+ if isinstance(model_data[0], BalancedMultiGPUModel):
238
  model_data[0].cleanup()
239
  elif isinstance(model_data[0], tuple):
 
240
  _, model = model_data[0]
241
  del model
242
  _loaded_models.clear()
 
245
 
246
 
247
  def embed_batch_raw(signals, masks_audio):
248
+ """
249
+ Waveform encoding in case it was chosen to skip self-supervised encording and push waveform directly to diffusion maps
250
+ :param signals: waveform signals.
251
+ :param masks_audio: voice activity masks of sources.
252
+ :return:
253
+ """
254
  win = int(ENERGY_WIN_MS * SR / 1000)
255
  hop = int(ENERGY_HOP_MS * SR / 1000)
256
  reps, L_max = [], 0
 
268
  def embed_batch_single_gpu(
269
  signals, masks_audio, extractor, model, layer, use_mlm=False
270
  ):
271
+ """
272
+ See embed_batch.
273
+ """
274
  if not signals:
275
  return torch.empty(0, 0, 0)
276
  device = next(model.parameters()).device
 
299
  mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
300
  all_keeps.append(hs[b][mask_t].cpu())
301
 
 
302
  del hs, input_values, inputs
303
  torch.cuda.empty_cache()
304
 
 
306
  L_max = max(x.shape[0] for x in all_keeps)
307
  keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
308
  result = torch.stack(keep_padded, dim=0)
 
309
  del all_keeps, keep_padded
310
  return result
311
  else:
 
313
 
314
 
315
  def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
316
+ """
317
+ Encode a batch of signals using the self-supervised model chosen.
318
+
319
+ :param signals: waveform signals to encode.
320
+ :param masks_audio: voice activity masks of sources.
321
+ :param model_wrapper: chosen model's wrapper.
322
+ :param layer: transformer layer.
323
+ :param use_mlm: deprecated.
324
+ :return: embedded signal representations by the model's layer.
325
+ """
326
  if model_wrapper == "raw":
327
  return embed_batch_raw(signals, masks_audio)
328
+ if isinstance(model_wrapper, BalancedMultiGPUModel):
329
  all_embeddings = []
330
  batch_size = min(BATCH_SIZE, 2)
331
  for i in range(0, len(signals), batch_size):
 
334
  )
335
  if batch_emb.numel() > 0:
336
  all_embeddings.append(batch_emb)
 
337
  torch.cuda.empty_cache()
338
 
339
  if all_embeddings:
utils.py CHANGED
@@ -3,18 +3,16 @@ import threading
3
  import warnings
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
-
7
  import numpy as np
8
  import torch
9
- try:
10
- from scipy.optimize import linear_sum_assignment as _lsa
11
- except Exception:
12
- _lsa = None
13
-
14
  warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")
15
 
16
 
17
  def get_gpu_count(max_gpus=None):
 
 
 
 
18
  ngpu = torch.cuda.device_count()
19
  if max_gpus is not None:
20
  ngpu = min(ngpu, max_gpus)
@@ -22,7 +20,9 @@ def get_gpu_count(max_gpus=None):
22
 
23
 
24
  def clear_gpu_memory():
25
- """Enhanced GPU memory clearing"""
 
 
26
  if torch.cuda.is_available():
27
  for i in range(torch.cuda.device_count()):
28
  with torch.cuda.device(i):
@@ -33,11 +33,15 @@ def clear_gpu_memory():
33
 
34
 
35
  def get_gpu_memory_info(verbose=False):
 
 
 
 
36
  if not verbose:
37
  return
38
  for i in range(torch.cuda.device_count()):
39
  try:
40
- free_b, total_b = torch.cuda.mem_get_info(i) # type: ignore[attr-defined]
41
  free_gb = free_b / 1024**3
42
  total_gb = total_b / 1024**3
43
  except Exception:
@@ -47,60 +51,10 @@ def get_gpu_memory_info(verbose=False):
47
  print(f"GPU {i}: {mem_allocated:.2f}GB allocated, {free_gb:.2f}GB free / {total_gb:.2f}GB total")
48
 
49
 
50
- def write_wav_16bit(path, x, sr=16000):
51
- path = Path(path)
52
- path.parent.mkdir(parents=True, exist_ok=True)
53
- try:
54
- import soundfile as sf
55
-
56
- sf.write(str(path), x.astype(np.float32), sr)
57
- except Exception:
58
- from scipy.io.wavfile import write
59
-
60
- write(str(path), sr, (np.clip(x, -1, 1) * 32767).astype(np.int16))
61
-
62
-
63
- def safe_corr_np(a, b):
64
- L = min(len(a), len(b))
65
- if L <= 1:
66
- return 0.0
67
- a = a[:L].astype(np.float64)
68
- b = b[:L].astype(np.float64)
69
- a -= a.mean()
70
- b -= b.mean()
71
- da = a.std()
72
- db = b.std()
73
- if da <= 1e-12 or db <= 1e-12:
74
- return 0.0
75
- r = float((a * b).mean() / (da * db))
76
- return max(-1.0, min(1.0, r))
77
-
78
-
79
- def hungarian(cost):
80
- try:
81
- if _lsa is not None:
82
- return _lsa(cost)
83
- raise RuntimeError("scipy.optimize.linear_sum_assignment unavailable")
84
- except Exception:
85
- used = set()
86
- rows, cols = [], []
87
- for i in range(cost.shape[0]):
88
- j = int(
89
- np.argmin(
90
- [
91
- cost[i, k] if k not in used else 1e12
92
- for k in range(cost.shape[1])
93
- ]
94
- )
95
- )
96
- used.add(j)
97
- rows.append(i)
98
- cols.append(j)
99
- return np.asarray(rows), np.asarray(cols)
100
-
101
-
102
  class GPUWorkDistributor:
103
-
 
 
104
  def __init__(self, max_gpus=None):
105
  ngpu = get_gpu_count(max_gpus)
106
  self.gpu_locks = [threading.Lock() for _ in range(max(1, min(ngpu, 2)))]
@@ -121,7 +75,6 @@ class GPUWorkDistributor:
121
  with torch.cuda.device(gid):
122
  kwargs["device"] = f"cuda:{gid}"
123
  result = func(*args, **kwargs)
124
- # Clear cache after execution
125
  torch.cuda.empty_cache()
126
  return result
127
  finally:
@@ -189,35 +142,12 @@ def canonicalize_mixtures(mixtures, systems=None):
189
 
190
  raise ValueError("Unsupported 'mixtures' format.")
191
 
192
-
193
- def random_misalign(sig, sr, max_ms, mode="single", rng=None):
194
- import random
195
-
196
- if rng is None:
197
- rng = random
198
- max_samples = int(sr * max_ms / 1000)
199
- if max_samples == 0:
200
- return sig
201
- shift = (
202
- rng.randint(-max_samples, max_samples) if mode == "range" else int(max_samples)
203
- )
204
- if shift == 0:
205
- return sig
206
- if isinstance(sig, torch.Tensor):
207
- z = torch.zeros(abs(shift), dtype=sig.dtype, device=sig.device)
208
- return (
209
- torch.cat([z, sig[:-shift]]) if shift > 0 else torch.cat([sig[-shift:], z])
210
- )
211
- else:
212
- z = np.zeros(abs(shift), dtype=sig.dtype)
213
- return (
214
- np.concatenate([z, sig[:-shift]])
215
- if shift > 0
216
- else np.concatenate([sig[-shift:], z])
217
- )
218
-
219
-
220
  def safe_cov_torch(X):
 
 
 
 
 
221
  Xc = X - X.mean(dim=0, keepdim=True)
222
  cov = Xc.T @ Xc / (Xc.shape[0] - 1)
223
  if torch.linalg.matrix_rank(cov) < cov.shape[0]:
@@ -226,6 +156,13 @@ def safe_cov_torch(X):
226
 
227
 
228
  def mahalanobis_torch(x, mu, inv):
 
 
 
 
 
 
 
229
  diff = x - mu
230
  diff_T = diff.transpose(-1, -2) if diff.ndim >= 2 else diff
231
  return torch.sqrt(diff @ inv @ diff_T + 1e-6)
 
3
  import warnings
4
  from dataclasses import dataclass
5
  from pathlib import Path
 
6
  import numpy as np
7
  import torch
 
 
 
 
 
8
  warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")
9
 
10
 
11
  def get_gpu_count(max_gpus=None):
12
+ """
13
+ Get the number of available GPUs.
14
+ :param max_gpus: maximal number of GPUs to utilize.
15
+ """
16
  ngpu = torch.cuda.device_count()
17
  if max_gpus is not None:
18
  ngpu = min(ngpu, max_gpus)
 
20
 
21
 
22
  def clear_gpu_memory():
23
+ """
24
+ Enhanced GPU memory clearing
25
+ """
26
  if torch.cuda.is_available():
27
  for i in range(torch.cuda.device_count()):
28
  with torch.cuda.device(i):
 
33
 
34
 
35
  def get_gpu_memory_info(verbose=False):
36
+ """
37
+ Get GPU memory info.
38
+ :param verbose: if True, get info.
39
+ """
40
  if not verbose:
41
  return
42
  for i in range(torch.cuda.device_count()):
43
  try:
44
+ free_b, total_b = torch.cuda.mem_get_info(i)
45
  free_gb = free_b / 1024**3
46
  total_gb = total_b / 1024**3
47
  except Exception:
 
51
  print(f"GPU {i}: {mem_allocated:.2f}GB allocated, {free_gb:.2f}GB free / {total_gb:.2f}GB total")
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  class GPUWorkDistributor:
55
+ """
56
+ Distribute GPU memory into multiple GPUs.
57
+ """
58
  def __init__(self, max_gpus=None):
59
  ngpu = get_gpu_count(max_gpus)
60
  self.gpu_locks = [threading.Lock() for _ in range(max(1, min(ngpu, 2)))]
 
75
  with torch.cuda.device(gid):
76
  kwargs["device"] = f"cuda:{gid}"
77
  result = func(*args, **kwargs)
 
78
  torch.cuda.empty_cache()
79
  return result
80
  finally:
 
142
 
143
  raise ValueError("Unsupported 'mixtures' format.")
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def safe_cov_torch(X):
146
+ """
147
+ Compute the covariance matrix of X.
148
+ :param X: array to compute covariance matrix of.
149
+ :return: regularized covariance matrix.
150
+ """
151
  Xc = X - X.mean(dim=0, keepdim=True)
152
  cov = Xc.T @ Xc / (Xc.shape[0] - 1)
153
  if torch.linalg.matrix_rank(cov) < cov.shape[0]:
 
156
 
157
 
158
  def mahalanobis_torch(x, mu, inv):
159
+ """
160
+ Compute the mahalanobis distance with x centered around mu with inverse covariance matrix inv.
161
+ :param x: point to calculates distance from.
162
+ :param mu: x is centered around mu.
163
+ :param inv: the inverse covariance matrix.
164
+ :return: Mahalanobis distance.
165
+ """
166
  diff = x - mu
167
  diff_T = diff.transpose(-1, -2) if diff.ndim >= 2 else diff
168
  return torch.sqrt(diff @ inv @ diff_T + 1e-6)