import json
import logging
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import List
from typing import Tuple
from typing import Union

import librosa
import torch
import numpy as np
from audiotools import AudioSignal

logging.basicConfig(level=logging.INFO)
###################
# beat sync utils #
###################

AGGREGATOR_REGISTRY = {
    "mean": np.mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
}


def list_aggregators() -> list:
    return list(AGGREGATOR_REGISTRY.keys())
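

# Illustrative note: these are plain numpy reducers; Beats.sync_features (below)
# looks them up by name, e.g.
#   AGGREGATOR_REGISTRY["median"](np.array([1.0, 2.0, 10.0]))  # -> 2.0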


@dataclass
class TimeSegment:
    start: float
    end: float

    def duration(self) -> float:
        return self.end - self.start

    def __str__(self) -> str:
        return f"{self.start} - {self.end}"

    def find_overlapping_segment(
        self, segments: List["TimeSegment"]
    ) -> Union["TimeSegment", None]:
        """Find the first segment that fully contains this segment, or None if there is none."""
        for s in segments:
            if s.start <= self.start and s.end >= self.end:
                return s
        return None
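

# Usage sketch (hypothetical values, not part of the original module). Note that
# find_overlapping_segment only returns a candidate that fully contains this segment:
#
#   seg = TimeSegment(1.0, 1.5)
#   bar = TimeSegment(0.0, 2.0)
#   seg.duration()                       # 0.5
#   seg.find_overlapping_segment([bar])  # returns `bar`, since it contains `seg`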


def mkdir(path: Union[Path, str]) -> Path:
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    return p


###################
# beat data #
###################


@dataclass
class BeatSegment(TimeSegment):
    downbeat: bool = False  # True if there is a downbeat at `start`


class Beats:
    def __init__(self, beat_times, downbeat_times):
        if isinstance(beat_times, np.ndarray):
            beat_times = beat_times.tolist()
        if isinstance(downbeat_times, np.ndarray):
            downbeat_times = downbeat_times.tolist()
        self._beat_times = beat_times
        self._downbeat_times = downbeat_times
        self._use_downbeats = False

    def use_downbeats(self, use_downbeats: bool = True):
        """Use downbeats instead of beats when calling get_beats."""
        self._use_downbeats = use_downbeats
    def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
        """
        Segment a song into time segments corresponding to beats.
        The first segment starts at 0 and ends at the first beat time.
        The last segment starts at the last beat time and ends at the end of the song.
        """
        beat_times = self._beat_times.copy()
        downbeat_times = self._downbeat_times
        beat_times.insert(0, 0)
        beat_times.append(signal.signal_duration)

        downbeat_ids = np.intersect1d(
            beat_times, downbeat_times, return_indices=True
        )[1]
        is_downbeat = [i in downbeat_ids for i in range(len(beat_times))]
        segments = [
            BeatSegment(start_time, end_time, downbeat)
            for start_time, end_time, downbeat in zip(
                beat_times[:-1], beat_times[1:], is_downbeat
            )
        ]
        return segments

    def get_beats(self) -> np.ndarray:
        """Return an array of beat times, in seconds.
        If use_downbeats is enabled, return an array of downbeat times instead.
        """
        return np.array(
            self._downbeat_times if self._use_downbeats else self._beat_times
        )

    def beat_times(self) -> np.ndarray:
        """Return beat times."""
        return np.array(self._beat_times)

    def downbeat_times(self) -> np.ndarray:
        """Return downbeat times."""
        return np.array(self._downbeat_times)
    def beat_times_to_feature_frames(
        self, signal: AudioSignal, features: np.ndarray
    ) -> np.ndarray:
        """Convert beat times to frame indices, given an array of time-varying features."""
        beat_times = self.get_beats()
        beat_frames = (
            beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
        ).astype(np.int64)
        return beat_frames

    def sync_features(
        self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
    ) -> np.ndarray:
        """Sync features to beats."""
        if aggregate not in AGGREGATOR_REGISTRY:
            raise ValueError(f"unknown aggregation method {aggregate}")

        return librosa.util.sync(
            features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
        )

    def to_json(self) -> dict:
        """Return beats and downbeats as a json-serializable dict."""
        return {
            "beats": self._beat_times,
            "downbeats": self._downbeat_times,
            "use_downbeats": self._use_downbeats,
        }
    @classmethod
    def from_dict(cls, data: dict):
        """Load beats and downbeats from a dict (as produced by to_json)."""
        inst = cls(data["beats"], data["downbeats"])
        inst.use_downbeats(data["use_downbeats"])
        return inst

    def save(self, output_dir: Path):
        """Save beats and downbeats to beats.json in output_dir."""
        output_dir = mkdir(output_dir)
        with open(output_dir / "beats.json", "w") as f:
            json.dump(self.to_json(), f)

    @classmethod
    def load(cls, input_dir: Path):
        """Load beats and downbeats from beats.json in input_dir."""
        beats_file = Path(input_dir) / "beats.json"
        with open(beats_file, "r") as f:
            data = json.load(f)
        return cls.from_dict(data)
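

# Usage sketch (assumes `signal` is an audiotools.AudioSignal and `features` is a
# (n_features, n_frames) array computed from it, e.g. an MFCC matrix; the beat
# times below are placeholders):
#
#   beats = Beats(beat_times=[0.48, 0.96, 1.44], downbeat_times=[0.48])
#   segments = beats.beat_segments(signal)          # one BeatSegment per inter-beat span
#   frames = beats.beat_times_to_feature_frames(signal, features)
#   synced = beats.sync_features(frames, features, aggregate="median")
#   beats.save(Path("out"))                         # writes out/beats.json
#   beats = Beats.load(Path("out"))                 # round-trips via from_dict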


###################
# beat tracking #
###################


class BeatTracker:
    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """Extract beat and downbeat times from an audio signal."""
        raise NotImplementedError

    def __call__(self, signal: AudioSignal) -> Beats:
        """Extract beats from an audio signal.

        NOTE: if the first beat (and/or downbeat) is detected within the first 100ms
        of the audio, it is discarded. This is to avoid empty bins with no
        beat-synced features in the first beat.

        Args:
            signal (AudioSignal): signal to beat track

        Returns:
            Beats: detected beats and downbeats
        """
        beats, downbeats = self.extract_beats(signal)
        return Beats(beats, downbeats)
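

# Illustrative sketch (not part of the original module, and not registered below):
# a minimal BeatTracker built on librosa's beat tracker. librosa does not estimate
# downbeats, so every 4th beat is treated as a downbeat purely as a placeholder
# heuristic under a 4/4 assumption.
class LibrosaBeats(BeatTracker):
    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """Return beat and (approximate) downbeat times, in seconds."""
        # mix down to mono and move to a numpy array for librosa
        audio = signal.clone().to_mono().audio_data.squeeze().cpu().numpy()
        _, beat_times = librosa.beat.beat_track(
            y=audio, sr=signal.sample_rate, units="time"
        )
        downbeat_times = beat_times[::4]  # placeholder: assume bars of 4 beats
        return beat_times, downbeat_times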


class WaveBeat(BeatTracker):
    def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
        from wavebeat.dstcn import dsTCNModel

        model = dsTCNModel.load_from_checkpoint(
            ckpt_path, map_location=torch.device(device)
        )
        model.eval()

        self.device = device
        self.model = model

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """Return beat and downbeat times, in seconds."""
        # extract beats with the pretrained WaveBeat model
        beats, downbeats = self.model.predict_beats_from_array(
            audio=signal.audio_data.squeeze(0),
            sr=signal.sample_rate,
            use_gpu=self.device != "cpu",
        )
        return beats, downbeats


class MadmomBeats(BeatTracker):
    def __init__(self):
        raise NotImplementedError

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """Return beat and downbeat times, in seconds."""
        raise NotImplementedError
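

# Possible direction for MadmomBeats, left as comments since it is untested here
# (assumes madmom's documented downbeat API; beat position 1 in the output marks
# a downbeat):
#
#   from madmom.features.downbeats import RNNDownBeatProcessor, DBNDownBeatTrackingProcessor
#   act = RNNDownBeatProcessor()(path_to_audio_file)
#   proc = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4], fps=100)
#   est = proc(act)                             # shape (n_beats, 2): [time, beat_position]
#   beat_times = est[:, 0]
#   downbeat_times = est[est[:, 1] == 1][:, 0]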


BEAT_TRACKER_REGISTRY = {
    "wavebeat": WaveBeat,
    "madmom": MadmomBeats,
}


def list_beat_trackers() -> list:
    return list(BEAT_TRACKER_REGISTRY.keys())


def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
    if beat_tracker not in BEAT_TRACKER_REGISTRY:
        raise ValueError(
            f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
        )
    return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
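

# End-to-end sketch (assumes a local "audio.wav" and a WaveBeat checkpoint at the
# default path; both are placeholders, not provided by this module):
if __name__ == "__main__":
    signal = AudioSignal("audio.wav")
    tracker = load_beat_tracker("wavebeat")
    beats = tracker(signal)
    logging.info(
        f"found {len(beats.beat_times())} beats and {len(beats.downbeat_times())} downbeats"
    )
    beats.save(Path("output"))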